#!/usr/bin/env python3
"""
LOAT: Latent-Order Adversarial Training with Multi-View Clustering
Complete implementation with all geometry features, proper causality, and evaluation metrics
"""

import os
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous.
# It is a debugging aid (accurate error locations) and slows GPU training;
# consider removing it or making it opt-in for real runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# cuBLAS workspace setting needed for deterministic cuBLAS results under
# torch.use_deterministic_algorithms(True) (enabled further below); set here,
# before torch is imported and CUDA is initialised.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


import sys
import random
import json
import math
import time
import copy
import logging
import pickle
import csv
from collections import defaultdict, deque
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Optional, Tuple, Any
from pathlib import Path
import joblib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset, BatchSampler
import torchvision
import torchvision.transforms as transforms
from scipy.stats import entropy as scipy_entropy
from sklearn.metrics import silhouette_score
import pandas as pd
from datetime import datetime
import torchvision.models as models
import itertools
from typing import List, Dict, Optional, Tuple, Any
from scipy.optimize import differential_evolution

# Determinism: seed the Python, NumPy and torch (CPU + all CUDA devices) RNGs
# at import time.
# NOTE(review): this uses the module constant SEED, not Config.seed (which has
# the same default value); confirm whether the entry point reseeds from cfg.seed.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Disable cuDNN auto-tuning (it may pick different kernels per run), disable
# TF32 for matmul/convolutions so results stay in full FP32, and require
# deterministic kernels (torch raises on ops that have no deterministic
# implementation). The cuBLAS part relies on CUBLAS_WORKSPACE_CONFIG, which is
# set at the top of this file.
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.use_deterministic_algorithms(True)

# Logging: a single module-wide "LOAT" logger, INFO level by default.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("LOAT")

# Dataset constants: per-channel CIFAR-10 mean/std. The profilers below turn
# these into (1, 3, 1, 1) tensors and pass them to Attacks.normalize.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)



# ========================= Configuration =========================
@dataclass
class Config:
    """Experiment configuration for LOAT runs.

    Every field has a default, so ``Config()`` is valid and individual settings
    can be overridden by keyword. Constructing a Config has a filesystem side
    effect: ``__post_init__`` creates ``experiment_dir``, a timestamped
    sub-directory of ``output_dir``. ``experiment_dir`` is a plain attribute,
    not a dataclass field, so ``asdict(cfg)`` does not include it.
    """

    # Experiment
    experiment_name: str = "loat_cifar10"
    seed: int = 1337
    # Evaluated once, when the class is defined (module import time).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset
    dataset: str = "cifar10"
    # NOTE(review): machine-specific default path; override per environment.
    data_root: str = "C:/Users/taiku/Documents/data"
    batch_size: int = 128
    num_workers: int = 2
    val_split: float = 0.05
    custom_order: Optional[str] = None
    excluded_cluster: Optional[int] = None

    # Model
    model_name: str = "resnet18"
    use_dual_bn: bool = True

    K_final: int = 5

    t_matrix_mode: str = "throughout"  # "throughout", "late", "converged"
    t_matrix_start_epoch: int = 20  # For "late" mode
    t_matrix_convergence_window: int = 5  # For "converged" mode

    # Training
    epochs: int = 30
    warmup_epochs: int = 2
    lr: float = 0.1
    momentum: float = 0.9
    weight_decay: float = 5e-4
    grad_clip_norm: float = 10.0
    lr_schedule: str = "multistep"
    # presumably fractions of `epochs` at which the LR drops — confirm in the scheduler
    lr_milestones: List[float] = field(default_factory=lambda: [0.5, 0.75])

    # EMA
    ema_enabled: bool = True
    ema_decay: float = 0.999

    # Adversarial
    epsilon: float = 8/255
    pgd_steps: int = 10
    pgd_step_size: float = 2/255
    eval_pgd_steps: int = 20
    eval_pgd_restarts: int = 2
    early_stop_pgd: bool = True
    run_autoattack: bool = False
    autoattack_freq: int = 5
    min_early_stop_steps: int = 3
    adaptive_pgd: bool = True
    order_mode: str = "ucb"
    order_min_edge: float = 0.10
    order_beam: int = 2
    order_update_every: int = 1
    log_paths: bool = True

    # TRADES
    trades_beta: float = 6.0
    beta_warmup_epochs: int = 2

    # Autoencoder
    ae_latent_dim: int = 64
    ae_train_epochs: int = 2
    ae_lr: float = 0.001
    use_denoising: bool = True
    noise_level: float = 0.05
    use_fgsm_noise: bool = True

    # Geometry features
    codebook_size: int = 32
    n_prototypes: int = 8
    n_slices: int = 32
    boe_temperature: float = 0.5  # Softmax temperature for bag-of-embeddings assignments
    use_fft: bool = True
    use_gram: bool = True
    # G7 topological features. Declared once only: a second declaration would
    # silently shadow this one.
    use_topo: bool = False
    class_agnostic: bool = False  # If True, stats won't use labels
    cluster_feature_type: str = "multi_view"

    # Clustering
    n_clusters: int = 5  # Final clusters
    K_stats: int = 5  # Stats view clusters
    K_geom: int = 5   # Geometry view clusters
    use_multiview: bool = True
    mv_mode: str = "coreg"  # "consensus" or "coreg"
    coreg_alpha: float = 0.2
    coreg_iters: int = 5

    # Discovery
    discovery_interval: int = 3
    cache_embeddings: bool = True  # Cache z_i per epoch

    # Scheduling
    use_ordering: bool = True
    use_cycles: bool = False
    test_cycle_modes: bool = True  # Test natural/reverse/random
    block_size: int = 10
    probe_interval: int = 20  # Lightweight probe every N blocks
    random_batch_ratio: float = 0.10

    # Bandit
    ucb_c: float = 1.5
    warmup_blocks: int = 5

    # Ablations
    ablation_mode: str = "full"  # "full", "stats_only", "geom_only", "single_cluster", "uniform_mix"
    single_cluster_id: int = 0  # For single_cluster ablation

    # Evaluation
    eval_interval: int = 1
    calibrate_bn: bool = True
    calibration_steps: int = 16

    # Output
    output_dir: str = "./experiments_loat"
    log_interval: int = 50

    # LOAT-T Teacher/Student modes
    mode: str = "baseline"  # "teacher", "student", "baseline"
    teacher_epochs: int = 15  # Quick teacher training
    teacher_robust_target: float = 0.40  # Target robustness for teacher
    recipe_path: Optional[str] = None  # Path to teacher recipe
    use_recipe: bool = False  # Use teacher recipe for student

    # Advanced features
    use_adversarial_ae: bool = False  # Adversarial autoencoder
    use_contrastive: bool = False  # Contrastive clustering
    uncertainty_ratio: float = 0.1  # Fraction for uncertain bucket

    save_interval: int = 5  # Save checkpoints
    hitl_enabled: bool = True  # Human-in-the-loop reporting
    pgd_budget_total: Optional[int] = None  # Total PGD calls allowed
    pgd_budget_mode: str = "none"  # "none", "stop", "throttle"

    # SimCLR
    use_simclr: bool = False
    simclr_epochs: int = 50
    simclr_feature_dim: int = 256

    # Teacher profiling
    profile_steps: List[int] = field(default_factory=lambda: [3, 5, 7, 10, 15, 20])
    profile_sample_size: int = 1000

    def __post_init__(self):
        """Create ``output_dir/<experiment_name>_<ablation_mode>_<YYYYmmdd_HHMMSS>``.

        Two configs created in the same second with the same name and ablation
        mode resolve to the same directory (``exist_ok=True``).
        """
        self.experiment_dir = Path(self.output_dir) / f"{self.experiment_name}_{self.ablation_mode}_{time.strftime('%Y%m%d_%H%M%S')}"
        self.experiment_dir.mkdir(parents=True, exist_ok=True)


# ========================= Dual BatchNorm =========================
class DualBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d with a second, independent set of running statistics for adversarial inputs.

    With ``use_adv = False`` this is the stock ``nn.BatchNorm2d`` forward
    (``running_mean`` / ``running_var`` / ``num_batches_tracked``). With
    ``use_adv = True`` it normalizes with ``running_mean_adv`` /
    ``running_var_adv`` and tracks ``num_batches_tracked_adv`` instead. The
    affine parameters ``weight`` / ``bias`` are shared by both branches.
    ``use_adv`` is a plain attribute; ``UseAdvBN`` toggles it model-wide.
    """

    def __init__(self, num_features, **kwargs):
        super().__init__(num_features, **kwargs)
        self.register_buffer('running_mean_adv', torch.zeros(num_features))
        self.register_buffer('running_var_adv', torch.ones(num_features))
        self.register_buffer('num_batches_tracked_adv', torch.tensor(0, dtype=torch.long))
        self.use_adv = False

    def forward(self, x):
        if not self.use_adv:
            return super().forward(x)

        if not self.training:
            # Eval: normalize with the adversarial running statistics, no update.
            return F.batch_norm(
                x, self.running_mean_adv, self.running_var_adv,
                self.weight, self.bias,
                False, 0.0, self.eps
            )

        # Train: mirror nn.BatchNorm2d's bookkeeping for the adversarial buffers,
        # so num_batches_tracked_adv is meaningful and momentum=None (cumulative
        # moving average) works instead of passing None into F.batch_norm.
        self.num_batches_tracked_adv.add_(1)
        if self.momentum is None:
            averaging_factor = 1.0 / float(self.num_batches_tracked_adv)
        else:
            averaging_factor = self.momentum
        return F.batch_norm(
            x, self.running_mean_adv, self.running_var_adv,
            self.weight, self.bias,
            True, averaging_factor, self.eps
        )


class UseAdvBN:
    """Context manager that temporarily sets ``use_adv`` on every DualBatchNorm2d in a model.

    On exit, each affected module's previous flag is restored, also when the
    body raised; exceptions are never suppressed. The same instance can be
    re-entered, sequentially or nested: every ``__enter__`` records its own
    snapshot and the matching ``__exit__`` restores exactly that snapshot.
    """

    def __init__(self, model: nn.Module, use_adv: bool):
        self.model = model
        self.use_adv = use_adv
        # One snapshot per active `with` entry: list of (module, previous use_adv).
        self._snapshots = []

    def __enter__(self):
        snapshot = []
        for module in self.model.modules():
            if isinstance(module, DualBatchNorm2d):
                snapshot.append((module, module.use_adv))
                module.use_adv = self.use_adv
        self._snapshots.append(snapshot)
        return self

    def __exit__(self, exc_type, exc, tb):
        if self._snapshots:
            for module, previous in self._snapshots.pop():
                module.use_adv = previous
        return False


# ========================= Attacks =========================
class Attacks:
    """Stateless collection of input-space attacks for training, probing and evaluation.

    Every attack takes un-normalized images in ``[0, 1]`` and applies
    ``(x - mean) / std`` before each forward pass. Each attack switches the
    model to eval mode (and to the requested BN branch via ``UseAdvBN``) for its
    duration. It restores the previous training mode on return, and also when
    an exception escapes.
    """

    @staticmethod
    def normalize(x: torch.Tensor, mean, std) -> torch.Tensor:
        """Return ``(x - mean) / std``.

        ``mean`` / ``std`` may be tensors, which are used as-is and so must
        broadcast against ``x`` on its device. They may also be 3-element
        sequences, which are turned into ``(1, 3, 1, 1)`` tensors on ``x.device``.
        """
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean, device=x.device).view(1, 3, 1, 1)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std, device=x.device).view(1, 3, 1, 1)
        return (x - mean) / std

    @staticmethod
    def fgsm(model, x, y, eps, mean, std, use_adv_bn=True):
        """Single-step FGSM: ``clamp(x + eps * sign(d CE(model(x), y) / dx), 0, 1)``.

        The gradient is taken with ``torch.autograd.grad`` w.r.t. the input only.
        The model's parameter ``.grad`` fields are therefore neither written nor
        cleared, and gradients a caller has accumulated survive this call.

        Returns:
            Detached adversarial images with the same shape as ``x``.
        """
        was_training = model.training
        model.eval()
        try:
            x_adv = x.detach().clone().requires_grad_(True)
            # enable_grad: this is also called from @torch.no_grad code paths.
            with UseAdvBN(model, use_adv_bn), torch.enable_grad():
                logits = model(Attacks.normalize(x_adv, mean, std))
                loss = F.cross_entropy(logits, y)
                grad = torch.autograd.grad(loss, x_adv)[0]
            return (x + eps * grad.sign()).clamp(0, 1).detach()
        finally:
            if was_training:
                model.train()

    @staticmethod
    def get_adaptive_margin_threshold(epoch, max_epochs):
        """Anneal margin threshold linearly from 0.8 (epoch 0) to 0.3 (epoch >= max_epochs)."""
        progress = min(1.0, epoch / max(1, max_epochs))
        return 0.8 - 0.5 * progress

    @staticmethod
    def _difficulty_step_budget(steps, diff, is_student_mode, min_early_stop_steps):
        """Map a cluster difficulty to a PGD step budget in ``[0, steps]``.

        ``diff >= 0.8`` keeps the full budget. ``diff >= 0.6`` uses
        ``int(steps * f)`` and lower difficulties use ``int(steps * (f - 0.1))``,
        with ``f = 0.7`` for students and ``0.8`` for teachers. The result is
        raised to ``min_early_stop_steps`` but never above ``steps``.
        """
        adaptation_factor = 0.7 if is_student_mode else 0.8
        if diff >= 0.80:
            budget = steps
        elif diff >= 0.60:
            budget = int(steps * adaptation_factor)
        else:
            budget = int(steps * (adaptation_factor - 0.1))
        # Honor the configured minimum, but never exceed the caller's full budget.
        return min(steps, max(min(budget, steps), min_early_stop_steps))

    @staticmethod
    def pgd_adaptive(model, x, y, eps, step_size, steps, mean, std,
                     random_start=True, use_adv_bn=True,
                     cluster_difficulties=None, cluster_id=None,
                     difficulty_profiles=None,
                     margin_threshold=0.5,
                     current_epoch=0, max_epochs=30,
                     min_early_stop_steps=3,
                     early_stop=True,
                     is_student_mode=False):
        """L-inf PGD with a difficulty-dependent budget and per-sample step scaling.

        Budget: ``steps`` by default. In student mode, or for a teacher from
        epoch 2 on, ``difficulty_profiles[cluster_id]['overall_difficulty']``
        (default 0.5), when present, selects a smaller budget via
        ``_difficulty_step_budget``.

        Step: every iteration moves each sample by ``step_size * m * sign(grad)``.
        ``m`` is 1.0 on the first step. Afterwards it is 0.7 when that sample's
        loss changed by more than the current batch-mean loss, else 1.3. The
        result is projected into the eps-ball around ``x`` and into [0, 1].

        Early stop (``early_stop=True``): once at least ``min_early_stop_steps``
        iterations have run, stop when the misclassified fraction of the batch
        reaches 0.7 (student) or 0.9 (teacher).

        ``cluster_difficulties`` is only used for debug logging.
        ``margin_threshold`` and ``max_epochs`` are currently unused and are kept
        for call compatibility.

        Returns:
            ``(x_adv, mean_steps)``: detached adversarial images and the number
            of PGD iterations executed, as a float.
        """
        if cluster_id is not None and cluster_difficulties:
            logger.debug("pgd_adaptive: cluster=%s, diff=%.3f, steps=%s",
                         cluster_id, cluster_difficulties.get(cluster_id, -1), steps)

        B = x.size(0)
        device = x.device

        was_training = model.training
        model.eval()
        try:
            if random_start:
                x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(0, 1)
            else:
                x_adv = x.clone()

            # Students use difficulty profiles immediately; teachers only after
            # a 2-epoch warmup with the full budget.
            sample_steps = steps
            if is_student_mode:
                use_profiles = True
                logger.debug("Student: bypassing warmup, using difficulty immediately")
            elif current_epoch < 2:
                use_profiles = False
                logger.debug("Teacher warmup epoch %s: using full %s steps", current_epoch, steps)
            else:
                use_profiles = True

            if use_profiles and difficulty_profiles and cluster_id in difficulty_profiles:
                diff = difficulty_profiles[cluster_id].get('overall_difficulty', 0.5)
                sample_steps = Attacks._difficulty_step_budget(
                    steps, diff, is_student_mode, min_early_stop_steps)
                logger.debug("Cluster %s: diff=%.3f, using %s/%s steps (student=%s)",
                             cluster_id, diff, sample_steps, steps, is_student_mode)

            success_threshold = 0.7 if is_student_mode else 0.9
            sharp_scale = torch.tensor(0.7, device=device)  # loss changing fast -> smaller step
            flat_scale = torch.tensor(1.3, device=device)   # loss changing slowly -> larger step
            prev_loss = None
            steps_taken = 0

            with UseAdvBN(model, use_adv_bn):
                for t in range(sample_steps):
                    x_adv = x_adv.detach().requires_grad_(True)
                    with torch.enable_grad():
                        logits = model(Attacks.normalize(x_adv, mean, std))
                        loss = F.cross_entropy(logits, y, reduction='none')
                        grad = torch.autograd.grad(loss.sum(), x_adv)[0]

                    with torch.no_grad():
                        if prev_loss is None:
                            step_multiplier = torch.ones(B, device=device)
                        else:
                            step_multiplier = torch.where(
                                (loss - prev_loss).abs() > loss.mean(), sharp_scale, flat_scale)
                        prev_loss = loss.detach()

                        sample_step = step_size * step_multiplier.view(B, 1, 1, 1)
                        x_adv = x_adv + sample_step * grad.sign()
                        x_adv = torch.max(torch.min(x_adv, x + eps), x - eps).clamp(0, 1)
                        steps_taken += 1

                        if early_stop and (t + 1) >= min_early_stop_steps:
                            pred = model(Attacks.normalize(x_adv, mean, std)).argmax(1)
                            misclassified = pred != y
                            if misclassified.float().mean() >= success_threshold:
                                logger.debug("Early stop at step %d: %d/%d misclassified",
                                             steps_taken, int(misclassified.sum()), len(y))
                                break

            return x_adv.detach(), float(steps_taken)
        finally:
            if was_training:
                model.train()

    @staticmethod
    def pgd(model, x, y, eps, step_size, steps, mean, std,
            random_start=True, use_adv_bn=True, early_stop=False, min_early_stop_steps=0):
        """Standard L-inf PGD (sign-gradient ascent on mean cross-entropy).

        With ``early_stop=True``, stops once every sample is misclassified. This
        is checked only after ``min_early_stop_steps`` iterations.

        Returns:
            ``(x_adv, steps_used)``: detached adversarial images and the int
            number of iterations run.
        """
        was_training = model.training
        model.eval()
        try:
            if random_start:
                x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(0, 1)
            else:
                x_adv = x.clone()

            steps_used = 0
            with UseAdvBN(model, use_adv_bn):
                for t in range(steps):
                    x_adv = x_adv.detach().requires_grad_(True)
                    # ensure grads are tracked even if caller is in no_grad
                    with torch.enable_grad():
                        loss = F.cross_entropy(model(Attacks.normalize(x_adv, mean, std)), y)
                        grad = torch.autograd.grad(loss, x_adv)[0]

                    with torch.no_grad():
                        x_adv = x_adv + step_size * grad.sign()
                        x_adv = torch.max(torch.min(x_adv, x + eps), x - eps).clamp(0, 1)
                    steps_used += 1

                    if early_stop and (t + 1) >= min_early_stop_steps:
                        with torch.no_grad():
                            pred = model(Attacks.normalize(x_adv, mean, std)).argmax(1)
                        if (pred != y).all():  # success on all items
                            break

            return x_adv.detach(), steps_used
        finally:
            if was_training:
                model.train()


# ========================= Autoencoder =========================
class DenoisingAutoencoder(nn.Module):
    """Small convolutional autoencoder used to embed images for geometry profiling.

    The encoder halves the spatial size three times and ends in a linear map, so
    it expects 32x32 inputs (flattened 128 x 4 x 4 feature map). The decoder
    mirrors it and ends in a sigmoid, so reconstructions lie in (0, 1).
    """

    def __init__(self, latent_dim=64):
        super().__init__()

        # Layers are created in a fixed order: that order defines both the
        # Sequential indices (state_dict keys) and the RNG draws for initialization.
        enc_layers = []
        for c_in, c_out in ((3, 32), (32, 64), (64, 128)):
            enc_layers += [nn.Conv2d(c_in, c_out, 3, stride=2, padding=1), nn.ReLU()]
        enc_layers += [
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        ]
        self.encoder = nn.Sequential(*enc_layers)

        self.fc_up = nn.Linear(latent_dim, 128 * 4 * 4)

        dec_layers = []
        for c_in, c_out in ((128, 64), (64, 32)):
            dec_layers += [nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1), nn.ReLU()]
        dec_layers += [nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1), nn.Sigmoid()]
        self.decoder = nn.Sequential(*dec_layers)

    def encode(self, x):
        """Map images (B, 3, 32, 32) to latent codes (B, latent_dim)."""
        return self.encoder(x)

    def decode(self, z):
        """Map latent codes (B, latent_dim) back to images (B, 3, 32, 32)."""
        feature_map = self.fc_up(z).reshape(z.shape[0], 128, 4, 4)
        return self.decoder(feature_map)

    def forward(self, x):
        """Return ``(reconstruction, latent_code)`` for a batch of images."""
        code = self.encode(x)
        return self.decode(code), code




# ========================= Geometry Profiler =========================
class GeometryProfiler:
    """Extract geometric features from batches (G1-G7).

    ``batch_geometry`` turns one batch of latent codes ``z_batch`` of shape
    (n, d), optionally with the matching images and a model, into one flat
    feature vector by concatenating:

      G1  per-dim mean and std, mean L2 norm, mean cosine dispersion from the mean
      G2  soft bag-of-embeddings histogram over the codebook (if fitted)
      G3  top-8 singular values of the centered batch + covariance trace
      G4  sliced W1 distance to each prototype (if fitted)
      G5  FFT band / orientation features of the images (``cfg.use_fft``)
      G6  Gram eigen-features of a mid-layer feature map (``cfg.use_gram``)
      G7  pairwise-distance summary (``cfg.use_topo``)
    """

    def __init__(self, cfg: Config, device):
        self.cfg = cfg
        self.device = device
        self.codebook = None    # (M, d) once fit_codebook has run
        self.prototypes = None  # (R, d) once fit_prototypes has run
        self.mean = torch.tensor(CIFAR10_MEAN).view(1, 3, 1, 1).to(device)
        self.std = torch.tensor(CIFAR10_STD).view(1, 3, 1, 1).to(device)

        # Embedding cache (not read or written by the methods of this class).
        self.embedding_cache = {}
        self.cache_epoch = -1

    @staticmethod
    def _soft_kmeans(Z: torch.Tensor, K: int, iters: int, temperature: float) -> torch.Tensor:
        """Soft k-means on Z (N, d): K centers seeded from random rows, refined ``iters`` times.

        Each iteration weights rows with softmax(-||z - c||^2 / temperature) and
        moves every center to the weighted mean of the rows.
        """
        N = Z.shape[0]
        C = Z[torch.randperm(N)[:K]].clone()
        for _ in range(iters):
            d2 = ((Z[:, None, :] - C[None, :, :]) ** 2).sum(-1)
            W = F.softmax(-d2 / temperature, dim=1)
            num = (W[:, :, None] * Z[:, None, :]).sum(0)
            den = W.sum(0)[:, None] + 1e-8
            C = num / den
        return C

    def fit_codebook(self, Z: torch.Tensor):
        """Fit the Bag-of-Embeddings codebook (G2) on Z (N, d) with soft k-means.

        Uses min(cfg.codebook_size, N) centers, 20 iterations and temperature
        ``cfg.boe_temperature``. The result is stored on ``self.device``.
        """
        Z = Z.float()
        N, _ = Z.shape
        # randperm(N)[:M] can only yield N rows; clamp so the logged count is the real one.
        M = min(self.cfg.codebook_size, N)
        self.codebook = self._soft_kmeans(Z, M, iters=20, temperature=self.cfg.boe_temperature).to(self.device)
        logger.info(f"Codebook fitted with {M} clusters")

    def fit_prototypes(self, Z: torch.Tensor):
        """Fit OT prototypes (G4) on Z (N, d): min(cfg.n_prototypes, N) soft k-means anchors, 10 iterations."""
        Z = Z.float()
        N, _ = Z.shape
        R = min(self.cfg.n_prototypes, N)
        # temperature=1.0 reproduces an un-tempered softmax(-d2) exactly.
        self.prototypes = self._soft_kmeans(Z, R, iters=10, temperature=1.0).to(self.device)
        logger.info(f"Prototypes fitted with {R} anchors")

    def _sliced_w1(self, Z: torch.Tensor) -> torch.Tensor:
        """Sliced W1 distance from the batch Z (n, d) to each prototype; returns shape (R,).

        Z and the prototypes are each centered on their own mean, then projected
        onto ``cfg.n_slices`` random unit directions. In 1-D, W1 between the
        empirical distribution of the n projected samples and a point mass at p
        is mean_i |z_i - p|. That does not depend on sample order, so no sort is
        needed. Distances are averaged over slices. Returns ``zeros(1)`` when no
        prototypes are fitted.
        """
        if self.prototypes is None:
            return torch.zeros(1, device=Z.device)

        _, d = Z.shape
        assert self.prototypes.shape[1] == d, f"Prototype dim {self.prototypes.shape[1]} != Z dim {d}"

        Z_centered = Z - Z.mean(0, keepdim=True)
        P_centered = self.prototypes - self.prototypes.mean(0, keepdim=True)

        u = torch.randn(self.cfg.n_slices, d, device=Z.device)
        u = u / (u.norm(dim=1, keepdim=True) + 1e-8)

        z_proj = Z_centered @ u.t()  # (n, L)
        p_proj = P_centered @ u.t()  # (R, L)

        # (R, n, L) absolute differences -> mean over samples and slices.
        return (z_proj.unsqueeze(0) - p_proj.unsqueeze(1)).abs().mean(dim=(1, 2))

    def _compute_fft_features(self, x_batch: torch.Tensor) -> torch.Tensor:
        """G5: spectral features of the grayscale images, returned as shape (10,).

        Five per-image features are computed from the centered |FFT| magnitude:
        mean magnitude in the low (r < 5), mid (5 <= r < 10) and high (r >= 10)
        radial bands, the low/high ratio, and the entropy of an 8-bin
        orientation histogram. The output is their mean over the batch followed
        by their std (zeros when the batch has one image).
        """
        batch_size = x_batch.size(0)
        device = x_batch.device

        x_gray = x_batch.mean(1)  # (B, H, W)
        # Shift only the spatial dims; the default would also roll the batch dim.
        spectrum_batch = torch.fft.fftshift(torch.abs(torch.fft.fft2(x_gray)), dim=(-2, -1))

        h, w = spectrum_batch.shape[1:]
        cy, cx = h // 2, w // 2

        y_grid, x_grid = torch.meshgrid(
            torch.arange(h, device=device),
            torch.arange(w, device=device),
            indexing='ij'
        )
        dy = (y_grid - cy).float()
        dx = (x_grid - cx).float()
        r = torch.sqrt(dx ** 2 + dy ** 2)

        low_energy = spectrum_batch[:, r < 5].mean(dim=1)
        mid_energy = spectrum_batch[:, (r >= 5) & (r < 10)].mean(dim=1)
        high_energy = spectrum_batch[:, r >= 10].mean(dim=1)
        low_high_ratio = low_energy / (high_energy + 1e-8)

        # Orientation histogram: bins are shared by all images, so masks are built once.
        theta = torch.atan2(dy, dx)
        n_bins = 8
        theta_bins = torch.linspace(-np.pi, np.pi, n_bins + 1, device=device)
        orientation_hist = torch.zeros(batch_size, n_bins, device=device)
        for j in range(n_bins):
            if j == n_bins - 1:
                # Close the last bin so theta == pi (negative x-axis) is counted.
                in_bin = (theta >= theta_bins[j]) & (theta <= theta_bins[j + 1])
            else:
                in_bin = (theta >= theta_bins[j]) & (theta < theta_bins[j + 1])
            if in_bin.any():
                orientation_hist[:, j] = spectrum_batch[:, in_bin].mean(dim=1)

        orientation_hist = orientation_hist / (orientation_hist.sum(dim=1, keepdim=True) + 1e-8)
        hist_safe = orientation_hist.clamp(min=1e-10)  # avoid log(0)
        orientation_entropies = -(hist_safe * torch.log(hist_safe)).sum(dim=1)

        features_stack = torch.stack([
            low_energy, mid_energy, high_energy,
            low_high_ratio, orientation_entropies
        ], dim=1)  # (B, 5)

        mean_features = features_stack.mean(0)
        if batch_size == 1:
            std_features = torch.zeros_like(mean_features)  # std of one sample is NaN
        else:
            std_features = features_stack.std(0)
        return torch.cat([mean_features, std_features])

    def _compute_gram_features(self, model: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """G6: Gram eigen-features of the layer3 feature map of the FIRST image in ``x``.

        Runs conv1 -> bn1 -> relu -> layer1 -> layer2 -> layer3 with clean BN
        statistics (NOTE(review): assumes a CIFAR-style ResNet exposing these
        attributes). With F the flattened (C, H*W) map and G = F F^T, it returns
        the top-min(8, C) eigenvalues of G, trace(G), and the Frobenius norm of
        G's off-diagonal part. The model's train/eval mode is restored afterwards.
        """
        was_training = model.training
        model.eval()
        try:
            # Only the first image is used, so don't push the rest of the batch through.
            x = x[:1].to(self.device)
            with torch.no_grad(), UseAdvBN(model, False):
                h = model.relu(model.bn1(model.conv1(Attacks.normalize(x, self.mean, self.std))))
                h = model.layer2(model.layer1(h))
                features = model.layer3(h)  # (1, C, H, W)
        finally:
            if was_training:
                model.train()

        c = features.shape[1]
        features_flat = features[0].reshape(c, -1)
        G = features_flat @ features_flat.t()  # (C, C)
        try:
            evals = torch.linalg.eigvalsh(G)  # ascending, real
        except RuntimeError:  # includes torch.linalg.LinAlgError
            evals = torch.zeros(c, device=G.device)
        k = min(8, evals.numel())
        topk = torch.topk(evals, k).values.real
        trace = torch.trace(G).real.unsqueeze(0)
        offdiag = (G - torch.diag(torch.diagonal(G))).pow(2).sum().sqrt().real.unsqueeze(0)
        return torch.cat([topk, trace, offdiag], dim=0).detach()

    def batch_geometry(self, z_batch, x_batch=None, model=None):
        """Return the concatenated G1-G7 feature vector for one batch of latent codes.

        Args:
            z_batch: latent codes, shape (n, d).
            x_batch: matching images in [0, 1]; needed for G5 and G6.
            model: classifier used for G6 (only when ``x_batch`` is also given).

        The output length depends on which views are enabled or fitted. A
        codebook or prototypes fitted at a different latent size are refitted on
        this batch. G1's per-dim std is NaN when n == 1.
        """
        n, d = z_batch.shape

        if self.codebook is not None and self.codebook.shape[1] != d:
            logger.warning(f"Codebook dim {self.codebook.shape[1]} != z_batch dim {d}, refitting codebook")
            self.fit_codebook(z_batch)
        if self.prototypes is not None and self.prototypes.shape[1] != d:
            logger.warning(f"Prototype dim {self.prototypes.shape[1]} != z_batch dim {d}, refitting prototypes")
            self.fit_prototypes(z_batch)

        features = []

        # G1: Set pooling statistics
        mu = z_batch.mean(0)
        sd = z_batch.std(0)
        energy = z_batch.norm(dim=1).mean()
        cosine_sim = F.cosine_similarity(z_batch, mu.unsqueeze(0).expand_as(z_batch), dim=1)
        dispersion = (1 - cosine_sim).mean()
        features.extend([mu, sd, energy.unsqueeze(0), dispersion.unsqueeze(0)])

        # G2: Bag-of-Embeddings
        if self.codebook is not None:
            d2 = ((z_batch[:, None, :] - self.codebook[None, :, :]) ** 2).sum(-1)
            W = F.softmax(-d2 / self.cfg.boe_temperature, dim=1)
            features.append(W.mean(0))

        # G3: Low-rank covariance spectrum (singular values, zero-padded to 8)
        z_centered = z_batch - mu
        try:
            _, S, _ = torch.pca_lowrank(z_centered, q=min(8, d - 1))
            top_eigs = S[:8]
            if len(top_eigs) < 8:
                top_eigs = F.pad(top_eigs, (0, 8 - len(top_eigs)))
        except (RuntimeError, ValueError):  # e.g. q > min(n, d) or a failed decomposition
            top_eigs = torch.zeros(8, device=z_batch.device)
        cov_trace = (z_centered.t() @ z_centered).trace() / n
        features.extend([top_eigs, cov_trace.unsqueeze(0)])

        # G4: OT/Transport signatures (W1)
        if self.prototypes is not None:
            features.append(self._sliced_w1(z_batch))

        # G5: FFT features (if enabled and images provided)
        if self.cfg.use_fft and x_batch is not None:
            features.append(self._compute_fft_features(x_batch))

        # G6: Gram features (needs both the model and the images)
        if self.cfg.use_gram and model is not None and x_batch is not None:
            features.append(self._compute_gram_features(model, x_batch))

        # G7: Topological proxy (placeholder for persistence diagrams). The pairwise
        # matrix includes its (near-)zero diagonal, so min is ~0.
        if self.cfg.use_topo:
            pdist = torch.cdist(z_batch, z_batch)
            features.append(torch.stack([pdist.mean(), pdist.std(), pdist.max(), pdist.min()]))

        return torch.cat([f.flatten() for f in features]).detach()


# ========================= Stats Profiler =========================
class StatsProfiler:
    """Extract statistical features from a model's clean and FGSM-perturbed predictions on a batch."""

    def __init__(self, device, cfg):
        self.device = device
        self.cfg = cfg
        self.mean = torch.tensor(CIFAR10_MEAN).view(1, 3, 1, 1).to(device)
        self.std = torch.tensor(CIFAR10_STD).view(1, 3, 1, 1).to(device)

    @torch.no_grad()
    def extract_batch_stats(self, model, x, y, eps):
        """Return a 1-D feature tensor describing the model's behaviour on batch ``x``.

        Feature order:
          mean / std of clean predictive entropy,
          mean / std of predictive entropy on the FGSM batch,
          mean entropy increase (adversarial - clean),
          fraction of FGSM samples whose prediction differs from the reference label,
          clean / adversarial cross-entropy w.r.t. ``y`` (only when labels are used),
          mean L2 and mean L-inf norm of the perturbation ``x_adv - x``.

        The reference label is ``y``. When ``cfg.class_agnostic`` is set or ``y``
        is None, the clean argmax is used instead and the two cross-entropy
        entries are omitted (8 instead of 10 features). The clean pass uses
        clean BN statistics and the FGSM pass uses adversarial BN statistics.
        Entropy stds are NaN for a single-sample batch. The model's train/eval
        mode is restored afterwards.

        Args:
            model: classifier taking normalized images.
            x: images in [0, 1] (un-normalized).
            y: integer labels, or None.
            eps: unused; the FGSM radius is ``self.cfg.epsilon``.
                NOTE(review): confirm whether callers expect ``eps`` to be honoured.
        """
        was_training = model.training
        model.eval()
        try:
            with UseAdvBN(model, False):
                logits_clean = model(Attacks.normalize(x, self.mean, self.std))
            probs_clean = F.softmax(logits_clean, dim=1)
            entropy_clean = -(probs_clean * torch.log(probs_clean + 1e-8)).sum(1)

            uses_labels = not self.cfg.class_agnostic and y is not None
            # Without labels, fall back to the model's own clean predictions.
            y_ref = y if uses_labels else logits_clean.argmax(1)

            # Weak adversarial probe
            x_adv = Attacks.fgsm(model, x, y_ref, self.cfg.epsilon, self.mean, self.std, use_adv_bn=True)
            with UseAdvBN(model, True):
                logits_adv = model(Attacks.normalize(x_adv, self.mean, self.std))
            probs_adv = F.softmax(logits_adv, dim=1)
            entropy_adv = -(probs_adv * torch.log(probs_adv + 1e-8)).sum(1)

            features = [
                entropy_clean.mean(),
                entropy_clean.std(),
                entropy_adv.mean(),
                entropy_adv.std(),
                (entropy_adv - entropy_clean).mean(),
                (logits_adv.argmax(1) != y_ref).float().mean(),
            ]
            if uses_labels:
                features += [F.cross_entropy(logits_clean, y), F.cross_entropy(logits_adv, y)]

            # Perturbation norm statistics (size of the FGSM step actually applied)
            delta = (x_adv - x).view(x.size(0), -1)
            features += [delta.norm(p=2, dim=1).mean(), delta.abs().max(dim=1)[0].mean()]

            # stack keeps everything on-device (no per-element host round trip).
            return torch.stack(features).to(self.device).detach()
        finally:
            if was_training:
                model.train()


# ========================= Multi-View Clusterer =========================
# ---------- STRATEGY CLUSTERERS (put these above MultiViewClusterer) ----------
from abc import ABC, abstractmethod
from sklearn.cluster import MiniBatchKMeans
import torch
from typing import Tuple, List, Dict, Optional
import pickle

# Optional density clustering (pip install hdbscan)
try:
    import hdbscan
    from hdbscan.prediction import approximate_predict as hdbscan_approx_predict
    HDBSCAN_AVAILABLE = True
except Exception:
    HDBSCAN_AVAILABLE = False


class Clusterer(ABC):
    """Interface for a pluggable clustering strategy over NumPy feature matrices.

    Concrete strategies must implement all four methods; the class itself
    cannot be instantiated.
    """

    @abstractmethod
    def fit(self, X: np.ndarray) -> np.ndarray:
        """Fit to the rows of ``X`` and return one label per row (length ``X.shape[0]``)."""
        ...

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Assign labels to previously unseen rows of ``X``."""
        ...

    @abstractmethod
    def get_params(self) -> dict:
        """Return a serializable dict from which ``set_params`` can restore the model."""
        ...

    @abstractmethod
    def set_params(self, params: dict):
        """Restore model state from a dict produced by ``get_params``."""
        ...


class MiniBatchKMeansClusterer(Clusterer):
    """Clusterer backed by scikit-learn's ``MiniBatchKMeans``.

    ``n_clusters``, ``random_state`` and ``max_iter`` mirror the wrapped
    estimator's settings, including after ``set_params`` restores a different one.
    """

    def __init__(self, n_clusters: int, random_state: int = 42, max_iter: int = 20):
        self.n_clusters = n_clusters
        self.random_state = random_state
        self.max_iter = max_iter
        self.model = MiniBatchKMeans(n_clusters=n_clusters, random_state=random_state, max_iter=max_iter)

    def fit(self, X: np.ndarray) -> np.ndarray:
        """Fit on ``X`` and return the cluster label of each row."""
        return self.model.fit_predict(X)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the nearest fitted cluster for each row of ``X``."""
        return self.model.predict(X)

    def get_params(self) -> dict:
        """Serialize the fitted estimator as pickled bytes under ``"sk_model"``."""
        return {"sk_model": pickle.dumps(self.model)}

    def set_params(self, params: dict):
        """Restore the estimator from ``params["sk_model"]``; a dict without that key is ignored.

        SECURITY: ``pickle.loads`` can execute arbitrary code. Only pass params
        produced by ``get_params`` from a trusted source.
        """
        if "sk_model" not in params:
            return
        self.model = pickle.loads(params["sk_model"])
        # Keep the convenience attributes consistent with the restored estimator.
        self.n_clusters = self.model.n_clusters
        self.random_state = self.model.random_state
        self.max_iter = self.model.max_iter


class HDBSCANClusterer(Clusterer):
    """Optional density-based clusterer (labels may contain -1 for noise).

    Noise points (label -1) are folded into cluster 0 both at fit and at
    predict time, so downstream code only ever sees non-negative labels.
    """
    def __init__(self, min_cluster_size: int = 15, min_samples: Optional[int] = None):
        if not HDBSCAN_AVAILABLE:
            raise RuntimeError("HDBSCAN not available. pip install hdbscan")
        # prediction_data=True is required by hdbscan.approximate_predict, which
        # ``predict`` relies on; without it every predict call raises.
        self.model = hdbscan.HDBSCAN(
            min_cluster_size=min_cluster_size,
            min_samples=min_samples,
            prediction_data=True,
        )

    def fit(self, X: np.ndarray) -> np.ndarray:
        """Fit on ``X`` and return labels with noise mapped to 0."""
        labels = self.model.fit_predict(X)
        # replace noise with 0 to keep downstream code simple
        return np.where(labels < 0, 0, labels)

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Approximately assign new rows to fitted clusters; noise maps to 0."""
        # approximate_predict returns (labels, strengths); strengths are unused here
        labels, _ = hdbscan_approx_predict(self.model, X)
        return np.where(np.asarray(labels) < 0, 0, labels)

    def get_params(self) -> dict:
        """Return the pickled model under the ``"sk_model"`` key."""
        return {"sk_model": pickle.dumps(self.model)}

    def set_params(self, params: dict):
        """Restore the model from ``params["sk_model"]`` if present.

        SECURITY: ``pickle.loads`` executes arbitrary code embedded in the data;
        only load state files you produced yourself.
        """
        if "sk_model" in params:
            self.model = pickle.loads(params["sk_model"])


# ---------------------- MULTI-VIEW CLUSTERER (drop-in replacement) ----------------------
from sklearn.metrics import silhouette_score, adjusted_rand_score


class MultiViewClusterer:
    """
    Multi-view clustering with pluggable strategies for the stats and geometry views.

    Each view is z-scored with scalers learned in ``fit`` and clustered by its
    own ``Clusterer`` strategy. The per-view labels (ks, kg) are combined into
    final labels, plus a row-normalized matrix of transitions between
    consecutive final labels.

    Modes:
      - "consensus" (default): build (ks,kg) pair bins → optionally compress to cfg.K_final
      - "coreg": light co-regularization encouraging ks ≈ kg, then consensus & compression

    Exposes:
      - fit(stats_features, geom_features, batch_ids_list)
      - predict_final_labels(stats_features, geom_features)
      - transform_features(stats_features, geom_features)
      - quality_metrics(...)
      - get_transition()
      - save_state(path) / load_state(path)
    """

    def __init__(self,
                 cfg,
                 stats_clusterer: Optional[Clusterer] = None,
                 geom_clusterer: Optional[Clusterer] = None,
                 mode: str = "consensus"):
        self.cfg = cfg
        self.mode = mode  # "consensus" | "coreg"

        # strategies
        self.stats_clusterer = stats_clusterer or MiniBatchKMeansClusterer(getattr(cfg, "K_stats", 7))
        self.geom_clusterer  = geom_clusterer  or MiniBatchKMeansClusterer(getattr(cfg, "K_geom", 7))

        # learned scalers
        self.stats_mu = None
        self.stats_sigma = None
        self.geom_mu = None
        self.geom_sigma = None

        # learned results
        self.ks = None          # per-row stats-view labels
        self.kg = None          # per-row geometry-view labels
        self.final_labels = None  # per-row final labels
        self.pair_to_final: Dict[Tuple[int,int], int] = {}
        self.T = None           # (Kf, Kf) transition matrix

    # ---------- utils: stable normalization ----------

    @staticmethod
    def _stable_sigma(s: torch.Tensor, eps: float = 1e-8, thr: float = 1e-6) -> torch.Tensor:
        """Return a divisor-safe std: NaN/inf and near-zero entries become 1, plus eps."""
        s_clean = torch.nan_to_num(s, nan=1.0, posinf=1.0, neginf=1.0)
        s2 = torch.where(s_clean < thr, torch.ones_like(s_clean), s_clean)
        return s2 + eps

    def _normalize_fit(self, Xs: torch.Tensor, Xg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Learn per-column mean/std for both views and return the z-scored views."""
        self.stats_mu = Xs.mean(0)
        self.stats_sigma = self._stable_sigma(Xs.std(0))
        self.geom_mu = Xg.mean(0)
        self.geom_sigma = self._stable_sigma(Xg.std(0))
        Xs_n = (Xs - self.stats_mu) / self.stats_sigma
        Xg_n = (Xg - self.geom_mu) / self.geom_sigma
        return Xs_n, Xg_n

    def transform_features(self, Xs: torch.Tensor, Xg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply learned scalers to new features.

        The scalers are moved to the device of the incoming features. ``load_state``
        restores them as CPU tensors, and without the move, normalizing
        accelerator-resident features would fail with a device mismatch.
        """
        assert self.stats_mu is not None and self.geom_mu is not None, "Call fit() first."
        Xs_n = (Xs - self.stats_mu.to(Xs.device)) / self.stats_sigma.to(Xs.device)
        Xg_n = (Xg - self.geom_mu.to(Xg.device)) / self.geom_sigma.to(Xg.device)
        return Xs_n, Xg_n

    # ---------- view clustering & co-regularization ----------

    def _fit_view_clusterers(self, Xs_n: torch.Tensor, Xg_n: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fit each view's strategy independently; return CPU long label tensors."""
        Xs_np = Xs_n.cpu().numpy()
        Xg_np = Xg_n.cpu().numpy()
        ks = self.stats_clusterer.fit(Xs_np)
        kg = self.geom_clusterer.fit(Xg_np)
        return torch.from_numpy(np.asarray(ks)).long(), torch.from_numpy(np.asarray(kg)).long()

    def _coreg_refine(self,
                      Xs_n: torch.Tensor, Xg_n: torch.Tensor,
                      ks: torch.Tensor, kg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encourage agreement between ks and kg with a light penalty.
        Does not assume K_stats == K_geom; uses modulo mapping if sizes differ.

        Each iteration adds ``alpha`` to the squared distance of every cluster
        except the one whose id equals the other view's label (mod K). Points
        then re-assign by argmin, and centroids are recomputed.
        NOTE(review): cluster ids of two independently fitted views are not
        aligned, so "same id" is an arbitrary pairing; consider matching them first.
        """
        alpha = getattr(self.cfg, "coreg_alpha", 0.2)
        iters = getattr(self.cfg, "coreg_iters", 5)

        Ks = int(max(ks.max().item() + 1, getattr(self.cfg, "K_stats", 7)))
        Kg = int(max(kg.max().item() + 1, getattr(self.cfg, "K_geom", 7)))

        def _centroids(X: torch.Tensor, labels: torch.Tensor, K: int) -> torch.Tensor:
            # Per-cluster mean of X; empty clusters get a jittered global mean.
            D = X.shape[1]
            labels = labels.to(X.device)
            C = torch.zeros(K, D, device=X.device)
            for k in range(K):
                m = (labels == k)
                if m.any():
                    C[k] = X[m].mean(0)
                else:
                    # random small init to avoid NaNs
                    C[k] = X.mean(0) + 1e-3 * torch.randn(D, device=X.device)
            return C

        Cs = _centroids(Xs_n, ks, Ks)
        Cg = _centroids(Xg_n, kg, Kg)
        ks_, kg_ = ks.clone(), kg.clone()
        s_ids = torch.arange(Ks, device=Xs_n.device)
        g_ids = torch.arange(Kg, device=Xg_n.device)

        for _ in range(iters):
            # squared distances to every centroid: (B, Ks) and (B, Kg)
            d2s = ((Xs_n[:, None, :] - Cs[None, :, :]) ** 2).sum(-1)
            d2g = ((Xg_n[:, None, :] - Cg[None, :, :]) ** 2).sum(-1)

            # alpha everywhere except the column matching the other view's label;
            # both masks use the labels from before this iteration's update.
            agree_s = s_ids[None, :] == (kg_.to(d2s.device) % Ks)[:, None]
            agree_g = g_ids[None, :] == (ks_.to(d2g.device) % Kg)[:, None]
            pen_s = torch.full_like(d2s, alpha).masked_fill_(agree_s, 0.0)
            pen_g = torch.full_like(d2g, alpha).masked_fill_(agree_g, 0.0)

            ks_ = (d2s + pen_s).argmin(1)
            kg_ = (d2g + pen_g).argmin(1)

            # update centroids
            Cs = _centroids(Xs_n, ks_, Ks)
            Cg = _centroids(Xg_n, kg_, Kg)

        return ks_, kg_

    # ---------- consensus & compression ----------

    @staticmethod
    def _pair_bins(ks: torch.Tensor, kg: torch.Tensor) -> Tuple[torch.Tensor, np.ndarray, np.ndarray]:
        """Map each distinct (ks, kg) pair to a dense id 0..P-1.

        Returns (dense labels, sorted unique pair codes, per-row pair codes).
        Pair codes are ``ks * 100000 + kg``, so they assume kg < 100000.
        """
        pairs_code = ks.cpu().numpy().astype(int) * 100000 + kg.cpu().numpy().astype(int)
        uniq, inv = np.unique(pairs_code, return_inverse=True)
        labels = torch.from_numpy(inv).long()  # 0..P-1
        return labels, uniq, pairs_code

    def _compress_to_K_final(self, labels: torch.Tensor, ks: torch.Tensor, kg: torch.Tensor) -> torch.Tensor:
        """If there are more pair bins than cfg.K_final, k-means the (ks, kg) coordinates down to K_final."""
        Kf_target = getattr(self.cfg, "K_final", int(labels.max().item()+1))
        if int(labels.max().item()+1) <= Kf_target:
            return labels
        coords = torch.stack([ks.float(), kg.float()], dim=1).cpu().numpy()
        reducer = MiniBatchKMeans(n_clusters=Kf_target, random_state=42, max_iter=20)
        labels_np = reducer.fit_predict(coords)
        return torch.from_numpy(labels_np).long()

    def _build_pair_to_final(self, ks: torch.Tensor, kg: torch.Tensor, labels: torch.Tensor):
        """Rebuild the (ks, kg) -> final-label lookup used by predict_final_labels."""
        self.pair_to_final.clear()
        for s, g, f in zip(ks.tolist(), kg.tolist(), labels.tolist()):
            self.pair_to_final[(int(s), int(g))] = int(f)

    # ---------- transition matrix ----------

    @staticmethod
    def _transition(labels: torch.Tensor) -> torch.Tensor:
        """Row-normalized counts of label[i] -> label[i+1] transitions (CPU float32).

        Rows of labels that never transition out (e.g. only seen last) stay all-zero.
        """
        if labels.numel() < 2:
            Kf = int(labels.max().item()+1) if labels.numel() > 0 else 1
            return torch.eye(Kf)
        # Count on CPU: torch.bincount has no deterministic CUDA kernel, and the
        # file enables torch.use_deterministic_algorithms.
        lab = labels.detach().reshape(-1).long().cpu()
        Kf = int(lab.max().item() + 1)
        flat = lab[:-1] * Kf + lab[1:]
        T = torch.bincount(flat, minlength=Kf * Kf).reshape(Kf, Kf).float()
        T = T / (T.sum(1, keepdim=True) + 1e-8)
        return T

    # ---------- PUBLIC API ----------

    def fit(self, stats_features: torch.Tensor, geom_features: torch.Tensor,
            batch_ids_list: Optional[List[List[int]]] = None, epoch=0):
        """
        Fit multi-view clustering; populate:
          self.ks, self.kg, self.final_labels, self.T, self.pair_to_final,
          and scalers (stats_mu/sigma, geom_mu/sigma).

        ``batch_ids_list`` and ``epoch`` are accepted for API compatibility and
        are not used here. Empty inputs log an error and install a minimal
        single-cluster state, including ks/kg, so no labels from a previous fit linger.
        """
        # Safety check
        if stats_features.numel() == 0 or geom_features.numel() == 0:
            logger.error("Empty features passed to clusterer")
            # Set minimal valid state
            self.stats_mu = torch.zeros(1, device=stats_features.device)
            self.stats_sigma = torch.ones(1, device=stats_features.device)
            self.geom_mu = torch.zeros(1, device=geom_features.device)
            self.geom_sigma = torch.ones(1, device=geom_features.device)
            self.ks = torch.zeros(1, dtype=torch.long)
            self.kg = torch.zeros(1, dtype=torch.long)
            self.final_labels = torch.zeros(1, dtype=torch.long)
            self.T = torch.eye(1)
            self.pair_to_final = {(0, 0): 0}
            return

        # normalize & cluster each view
        Xs_n, Xg_n = self._normalize_fit(stats_features, geom_features)
        ks, kg = self._fit_view_clusterers(Xs_n, Xg_n)

        # optional co-regularization
        if self.mode == "coreg":
            ks, kg = self._coreg_refine(Xs_n, Xg_n, ks, kg)

        # consensus pairs → labels
        labels, _, _ = self._pair_bins(ks, kg)

        # compress to K_final if too many pair bins
        labels = self._compress_to_K_final(labels, ks, kg)

        # persist mapping (ks,kg)->final
        self._build_pair_to_final(ks, kg, labels)

        # store
        self.ks = ks
        self.kg = kg
        self.final_labels = labels
        self.T = self._transition(labels)

    def predict_final_labels(self, stats_features: torch.Tensor, geom_features: torch.Tensor) -> torch.Tensor:
        """
        Predict final labels for new batches using stored scalers and mapping.
        If an unseen (ks,kg) pair occurs, fall back to 0.
        Empty input returns an empty label tensor without calling the view clusterers.
        The result is placed on the device of ``stats_features`` when it has one.
        """
        assert self.stats_mu is not None and self.geom_mu is not None, "Call fit() first."
        device = getattr(stats_features, 'device', None)
        if stats_features.shape[0] == 0:
            empty = torch.zeros(0, dtype=torch.long)
            return empty.to(device) if device is not None else empty
        Xs_n, Xg_n = self.transform_features(stats_features, geom_features)
        ks = np.asarray(self.stats_clusterer.predict(Xs_n.detach().cpu().numpy())).astype(int)
        kg = np.asarray(self.geom_clusterer.predict(Xg_n.detach().cpu().numpy())).astype(int)
        out = [self.pair_to_final.get((int(s), int(g)), 0) for s, g in zip(ks, kg)]
        result = torch.tensor(out, dtype=torch.long)
        # Ensure device consistency
        if device is not None:
            result = result.to(device)
        return result

    @staticmethod
    def _safe_silhouette(X: torch.Tensor, labels: Optional[np.ndarray]) -> float:
        """Euclidean silhouette of ``labels`` on ``X``; 0.0 whenever it is undefined or fails."""
        if labels is None or len(np.unique(labels)) < 2:
            return 0.0
        try:
            data = X.detach().cpu().numpy()
            if not np.isfinite(data).all() or data.shape[0] <= 1:
                return 0.0
            return float(silhouette_score(data, labels, metric="euclidean"))
        except Exception:
            # e.g. label/row count mismatch or too many distinct labels
            return 0.0

    def quality_metrics(self,
                        stats_features: torch.Tensor,
                        geom_features: torch.Tensor) -> Dict[str, float]:
        """
        Returns guard-safe clustering quality metrics:
            silhouette_stats, silhouette_geom, ARI_stats_geom
        Silhouette is computed per-view on normalized features with their own
        labels from the last fit. Undefined metrics are 0.0. If normalization
        itself fails, any missing metric is NaN.
        """
        out = {}
        try:
            Xs_n, Xg_n = self.transform_features(stats_features, geom_features)
            ks_np = self.ks.cpu().numpy() if self.ks is not None else None
            kg_np = self.kg.cpu().numpy() if self.kg is not None else None

            out["silhouette_stats"] = self._safe_silhouette(Xs_n, ks_np)
            out["silhouette_geom"] = self._safe_silhouette(Xg_n, kg_np)

            # ARI between views
            if ks_np is not None and kg_np is not None:
                try:
                    out["ARI_stats_geom"] = float(adjusted_rand_score(ks_np, kg_np))
                except Exception:
                    out["ARI_stats_geom"] = 0.0
            else:
                out["ARI_stats_geom"] = 0.0

        except Exception:
            # stay robust
            out.setdefault("silhouette_stats", float("nan"))
            out.setdefault("silhouette_geom", float("nan"))
            out.setdefault("ARI_stats_geom", float("nan"))
        return out

    def get_transition(self) -> torch.Tensor:
        """Copy of the fitted transition matrix, or identity(cfg.K_final or 5) before fit."""
        return self.T.clone() if self.T is not None else torch.eye(getattr(self.cfg, "K_final", 5))

    # ---------- persistence ----------

    def save_state(self, path: str):
        """Save clusterer state with robust scaler serialization.

        Scalers are written to a ``.scalers.joblib`` sidecar and are also always
        embedded in the main pickle, so ``load_state`` can recover them if the
        sidecar is moved or missing.
        """
        def _np(t):
            return t.detach().cpu().numpy() if t is not None else None

        # Prepare main state dict
        payload = {
            'pair_to_final': self.pair_to_final,
            'final_labels': _np(self.final_labels),
            'ks': _np(self.ks),
            'kg': _np(self.kg),
            'T': _np(self.T),
            'mode': self.mode,
        }

        scalers_dict = {
            'stats_mu': _np(self.stats_mu),
            'stats_sigma': _np(self.stats_sigma),
            'geom_mu': _np(self.geom_mu),
            'geom_sigma': _np(self.geom_sigma),
        }
        # Small arrays: embedding them costs little and makes the sidecar optional.
        payload.update(scalers_dict)

        try:
            # Use joblib for scalers (handles numpy arrays better)
            scalers_path = Path(path).with_suffix('.scalers.joblib')
            joblib.dump(scalers_dict, scalers_path)
            payload['scalers_path'] = str(scalers_path)
            logger.info(f"Saved scalers to {scalers_path}")
        except Exception as e:
            logger.warning(f"Scaler save with joblib failed: {e}, falling back to pickle")

        # Save clusterer models
        payload['stats_clusterer'] = self.stats_clusterer.get_params()
        payload['geom_clusterer'] = self.geom_clusterer.get_params()

        # Save main state
        with open(path, 'wb') as f:
            pickle.dump(payload, f)
        logger.info(f"Saved clusterer state to {path}")

    def load_state(self, path: str):
        """Load clusterer state with fallback for missing scalers.

        Scalers are looked up in order: the recorded sidecar path, then a sidecar
        next to ``path`` (handles moved directories or a changed cwd), then the
        arrays embedded in the main pickle. Restored tensors live on the CPU.

        SECURITY: this unpickles ``path``; only load files you produced yourself.
        """
        with open(path, 'rb') as f:
            payload = pickle.load(f)

        def _as_tensor(x):
            return torch.from_numpy(x) if isinstance(x, np.ndarray) else x

        scalers_loaded = False

        if 'scalers_path' in payload:
            candidates = [Path(payload['scalers_path'])]
            try:
                candidates.append(Path(path).with_suffix('.scalers.joblib'))
            except ValueError:
                pass
            for scalers_path in candidates:
                if scalers_loaded or not scalers_path.exists():
                    continue
                try:
                    scalers_dict = joblib.load(scalers_path)
                    self.stats_mu = _as_tensor(scalers_dict['stats_mu'])
                    self.stats_sigma = _as_tensor(scalers_dict['stats_sigma'])
                    self.geom_mu = _as_tensor(scalers_dict['geom_mu'])
                    self.geom_sigma = _as_tensor(scalers_dict['geom_sigma'])
                    scalers_loaded = True
                    logger.info(f"Loaded scalers from {scalers_path}")
                except Exception as e:
                    logger.warning(f"Failed to load joblib scalers: {e}")

        # Fallback: try to load from main payload
        if not scalers_loaded:
            for key in ['stats_mu', 'stats_sigma', 'geom_mu', 'geom_sigma']:
                if key in payload and payload[key] is not None:
                    setattr(self, key, _as_tensor(payload[key]))

            # Check if we got them
            if self.stats_mu is not None and self.geom_mu is not None:
                logger.info("Loaded scalers from embedded state")
                scalers_loaded = True

        if not scalers_loaded:
            logger.warning("No scalers found in saved state - will need to refit on first use")

        # Load other state
        def _to_torch(x):
            return torch.from_numpy(x) if x is not None else None

        self.pair_to_final = payload.get('pair_to_final', {})
        self.final_labels = _to_torch(payload.get('final_labels'))
        self.ks = _to_torch(payload.get('ks'))
        self.kg = _to_torch(payload.get('kg'))
        self.T = _to_torch(payload.get('T'))
        self.mode = payload.get('mode', 'consensus')

        # Restore clusterer models
        self.stats_clusterer.set_params(payload.get('stats_clusterer', {}))
        self.geom_clusterer.set_params(payload.get('geom_clusterer', {}))

    def extract_confidence_patterns(self, train_loader):
        """Extract prediction confidence distributions.

        Returns a CPU tensor with one row per sample: [max softmax probability,
        predictive entropy]. Inputs are moved to the device of the model's
        parameters. The model's previous train/eval mode is restored afterwards.
        Entropy uses log_softmax, so probabilities that underflow to 0 do not
        produce NaN.

        NOTE(review): reads ``self.model``, which this class never sets; a caller
        must assign it first. Assumes the loader yields (inputs, targets) pairs
        and the model returns (batch, classes) logits — confirm.
        """
        model = self.model
        was_training = model.training
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = torch.device("cpu")

        features = []
        model.eval()
        try:
            with torch.no_grad():
                for inputs, _targets in train_loader:
                    outputs = model(inputs.to(device))
                    log_probs = F.log_softmax(outputs, dim=1)
                    probs = log_probs.exp()

                    max_prob, _ = probs.max(dim=1)
                    entropy = -(probs * log_probs).sum(dim=1)

                    features.append(torch.stack([max_prob, entropy], dim=1).cpu())
        finally:
            model.train(was_training)

        if not features:
            return torch.empty(0, 2)
        return torch.cat(features)

# ========================= Group Batch Sampler =========================
class GroupBatchSampler(BatchSampler):
    """Batch sampler that draws every batch from a single index group.

    ``order`` lists group ids, one entry per batch. Each entry yields the next
    ``batch_size`` unused indices of that group, with indices reshuffled on every
    iteration. Unknown groups, empty groups, and entries beyond a group's batch
    capacity are dropped at construction. ``__len__`` reports exactly the number
    of batches ``__iter__`` yields.
    """

    def __init__(self, groups: Dict[int, List[int]], order: List[int],
                 batch_size: int, dataset, drop_small: bool = False):
        """
        groups[k]: flat list of TRUE dataset indices
        dataset: the dataset being used (may be a Subset)
        order: group id for each batch to emit, in emission order
        batch_size: maximum number of indices per batch
        drop_small: if True, partial (< batch_size) batches are never yielded
        """
        self.batch_size = batch_size
        self.drop_small = drop_small

        if hasattr(dataset, 'indices'):
            # It's a Subset: translate true dataset indices into subset positions.
            # int() keys make tensor/numpy index types compare by value.
            true_to_subset = {int(true_idx): i for i, true_idx in enumerate(dataset.indices)}
            self.groups = {}
            for k, true_indices in groups.items():
                subset_indices = [true_to_subset[int(t)] for t in true_indices if int(t) in true_to_subset]
                if subset_indices:  # Only add non-empty groups
                    self.groups[k] = subset_indices
        else:
            # Not a subset: use indices as-is, dropping empty groups like the Subset branch does.
            self.groups = {k: v for k, v in groups.items() if len(v) > 0}

        valid_groups = set(self.groups.keys())

        # Batches each group can actually provide: full batches, plus one partial
        # batch unless drop_small (a group smaller than batch_size has exactly one).
        group_capacity = {}
        for g in valid_groups:
            full, rem = divmod(len(self.groups[g]), batch_size)
            group_capacity[g] = max(1, full + (1 if rem and not drop_small else 0))

        # Filter the order based on validity and capacity
        filtered_order = []
        group_counts = {g: 0 for g in valid_groups}
        for g in order:
            if g in valid_groups and group_counts[g] < group_capacity[g]:
                filtered_order.append(g)
                group_counts[g] += 1

        self.order = filtered_order

        expected_batches = math.ceil(sum(len(g) for g in self.groups.values()) / batch_size)
        if len(self.order) < len(order) * 0.5:  # Only warn if lost more than 50%
            logger.warning(f"Major order loss: {len(order)} -> {len(self.order)}")
            # Try to recover by cycling through available groups
            target = min(len(order), expected_batches * 0.9)
            while len(self.order) < target:
                added_any = False
                for g in sorted(self.groups.keys()):
                    if group_counts[g] < group_capacity[g]:
                        self.order.append(g)
                        group_counts[g] += 1
                        added_any = True
                        if len(self.order) >= target:
                            break
                if not added_any:  # Prevent infinite loop if all groups exhausted
                    break

        # Assert cluster coverage
        order_clusters = set(self.order)
        available_clusters = set(self.groups.keys())
        if not order_clusters.issubset(available_clusters):
            invalid = order_clusters - available_clusters
            raise ValueError(f"Order contains invalid clusters: {invalid}")

        logger.info(f"GroupBatchSampler ready: {len(self.order)} batches covering clusters {sorted(order_clusters)}")

    def __iter__(self):
        """Yield batches of (subset) indices following ``self.order``."""
        import random

        # Prepare all groups with shuffling
        shuffled_groups = {}
        for k, indices in self.groups.items():
            shuffled_indices = indices.copy()
            random.shuffle(shuffled_indices)
            shuffled_groups[k] = shuffled_indices

        pointers = {k: 0 for k in self.groups}

        for group_id in self.order:
            if group_id not in shuffled_groups:
                continue

            indices = shuffled_groups[group_id]
            start = pointers[group_id]
            batch = list(indices[start:start + self.batch_size])
            pointers[group_id] = start + len(batch)

            # Only yield full batches, or partial ones when not dropping small
            if batch and (len(batch) == self.batch_size or not self.drop_small):
                yield batch

    def __len__(self):
        """Exact number of batches ``__iter__`` yields (batch sizes do not depend on the shuffle)."""
        remaining = {k: len(v) for k, v in self.groups.items()}
        n_batches = 0
        for group_id in self.order:
            if group_id not in remaining:
                continue
            take = min(self.batch_size, remaining[group_id])
            remaining[group_id] -= take
            if take > 0 and (take == self.batch_size or not self.drop_small):
                n_batches += 1
        return n_batches


class AdaptiveMultiViewClusterer(MultiViewClusterer):
    """
    Extends MultiViewClusterer with:
    1. Momentum-based feature accumulation
    2. Hierarchical conflict resolution
    3. Delayed clustering after warmup
    """

    def __init__(self, cfg):
        super().__init__(cfg)  # Initialize parent class

        # Additional attributes for adaptive clustering
        self.momentum = 0.9
        self.feature_bank_stats = None
        self.feature_bank_geom = None
        self.warmup_epochs = 3
        self.is_warmed_up = False
        self.clustering_history = []
        self.view_confidences = {'stats': [], 'geom': []}

    def fit(self, stats_features, geom_features, batch_ids_list=None, epoch=0):
        """Override fit with progressive clustering.

        - epoch < warmup_epochs: round-robin labels over cfg.n_clusters, an
          identity transition matrix, and scalers fitted on the raw features.
          The view clusterers are not fitted.
        - afterwards: update an EMA feature bank (momentum ``self.momentum``),
          fit the parent on the bank, and resolve view conflicts when raw ks/kg
          ids agree on fewer than 70% of rows.
        The bank is reset whenever the incoming feature shapes change (e.g. a
        different number of batches), instead of failing on the EMA update.
        """
        # During warmup, return uniform clustering
        if epoch < self.warmup_epochs:
            dev = stats_features.device
            n_batches = len(stats_features)
            self.final_labels = torch.arange(n_batches, device=dev, dtype=torch.long) % self.cfg.n_clusters
            self.T = torch.eye(self.cfg.n_clusters, device=dev)
            # Also set minimal scalers to avoid downstream errors; use the same
            # NaN/near-zero guard as the parent so a single row or a constant
            # column does not yield NaN or 1e6-scaled features.
            self.stats_mu = stats_features.mean(0)
            self.stats_sigma = self._stable_sigma(stats_features.std(0))
            self.geom_mu = geom_features.mean(0)
            self.geom_sigma = self._stable_sigma(geom_features.std(0))
            return

        # Update feature banks with momentum (rows are aligned by position)
        bank_compatible = (
            self.feature_bank_stats is not None
            and self.feature_bank_geom is not None
            and self.feature_bank_stats.shape == stats_features.shape
            and self.feature_bank_geom.shape == geom_features.shape
        )
        if not bank_compatible:
            if self.feature_bank_stats is not None:
                logger.info(
                    f"Resetting feature bank: shape {tuple(self.feature_bank_stats.shape)} "
                    f"-> {tuple(stats_features.shape)}"
                )
            self.feature_bank_stats = stats_features.clone()
            self.feature_bank_geom = geom_features.clone()
        else:
            m = self.momentum
            self.feature_bank_stats = (m * self.feature_bank_stats.to(stats_features.device) +
                                       (1 - m) * stats_features)
            self.feature_bank_geom = (m * self.feature_bank_geom.to(geom_features.device) +
                                      (1 - m) * geom_features)

        # Now call parent's fit on the accumulated features
        super().fit(self.feature_bank_stats, self.feature_bank_geom, batch_ids_list)

        # After parent's fit, apply hierarchical conflict resolution.
        # ks/kg always exist as attributes (set in __init__), so test for values.
        if self.ks is not None and self.kg is not None and self.ks.shape == self.kg.shape:
            # NOTE(review): compares raw label ids of two independently fitted
            # clusterers; ids are not aligned across views, so this is a heuristic.
            agreement_rate = (self.ks == self.kg).float().mean()

            if agreement_rate < 0.7:  # Low agreement threshold
                self._resolve_conflicts()

    def _resolve_conflicts(self):
        """Resolve conflicts between views using confidence scores.

        When the views disagree anywhere and their silhouettes differ, the
        labels of the view with the higher silhouette become the final labels.
        The transition matrix and the (ks, kg) -> final mapping are rebuilt so
        predict_final_labels stays consistent with final_labels.
        """
        Xs_n, Xg_n = self.transform_features(self.feature_bank_stats, self.feature_bank_geom)

        try:
            sil_stats = silhouette_score(Xs_n.detach().cpu().numpy(), self.ks.cpu().numpy())
        except Exception:
            sil_stats = 0.0

        try:
            sil_geom = silhouette_score(Xg_n.detach().cpu().numpy(), self.kg.cpu().numpy())
        except Exception:
            sil_geom = 0.0

        # For samples where views disagree, use higher confidence view
        disagreement_mask = (self.ks != self.kg)
        if disagreement_mask.any() and sil_stats != sil_geom:
            trusted = self.ks if sil_stats > sil_geom else self.kg
            self.final_labels = trusted.clone()

            # Rebuild transition matrix and prediction mapping with resolved labels
            self.T = self._transition(self.final_labels)
            self._build_pair_to_final(self.ks, self.kg, self.final_labels)


# ========================= Order Scheduler =========================
class OrderScheduler:
    """Schedule batch order over clusters with a UCB bandit, a transition matrix or fixed cycles.

    Bandit state:
        rewards: cluster id -> list of scalar rewards (see :meth:`update_reward`).
        counts: cluster id -> number of rewards received.
        total_pulls: total number of rewards received.
    ``transition_counts[a, b]`` counts consecutive pairs of clusters passed to
    :meth:`update_reward`.
    """

    def __init__(self, n_clusters, cfg):
        self.n_clusters = n_clusters
        self.cfg = cfg

        # Bandit state
        self.rewards = defaultdict(list)
        self.counts = defaultdict(int)
        self.total_pulls = 0

        # Per-group probe results
        self.probe_results = defaultdict(list)

        # Transition tracking
        self.last_cluster = None
        self.transition_counts = np.zeros((n_clusters, n_clusters))

    def build_natural_order_with_fallbacks(self, groups, T, epoch):
        """Walk the transition matrix ``T`` to build a batch order, with rare forced jumps.

        The walk starts at :meth:`get_foundation_cluster`. Each step samples
        the next cluster from ``T[current]`` mixed with 2% uniform exploration.
        After 4 consecutive self-transitions the walk is forced to jump to a
        different non-empty cluster. When ``T`` has no row for the current
        cluster (or is None), the next cluster is drawn uniformly from
        ``groups``' keys.

        Args:
            groups: mapping cluster id -> sample indices.
            T: square transition matrix (numpy array or torch tensor), or None.
            epoch: unused; kept for interface compatibility.

        Returns:
            List of ``ceil(n_samples / cfg.batch_size)`` cluster ids.
        """
        n_samples = sum(len(v) for v in groups.values())
        n_batches = math.ceil(n_samples / self.cfg.batch_size)

        order = []
        current = self.get_foundation_cluster(T)
        consecutive_same = 0

        for i in range(n_batches):
            order.append(current)

            if i < n_batches - 1:
                if T is not None and current < len(T):
                    T_row = T[current].cpu().numpy() if hasattr(T[current], 'cpu') else T[current]

                    if consecutive_same >= 4:  # Rare fallback: break a stuck pattern
                        logger.debug(f"Breaking stuck pattern at cluster {current} after {consecutive_same} steps")
                        # Iterate the actual keys (sorted for determinism) so that
                        # non-contiguous cluster ids cannot raise KeyError.
                        available_clusters = [c for c in sorted(groups)
                                              if c != current and len(groups[c]) > 0]
                        next_cluster = np.random.choice(available_clusters) if available_clusters else current
                        consecutive_same = 0
                    else:
                        # Follow natural progression with minimal exploration
                        epsilon = 0.02
                        T_row = (1 - epsilon) * T_row + epsilon / len(T_row)
                        T_row = T_row / (T_row.sum() + 1e-8)

                        next_cluster = np.random.choice(len(T_row), p=T_row)
                        consecutive_same = consecutive_same + 1 if next_cluster == current else 0
                else:
                    # Fallback if no T matrix
                    next_cluster = np.random.choice(list(groups.keys()))
                    consecutive_same = 0

                current = next_cluster

        return order

    def get_foundation_cluster(self, T):
        """Return the cluster with the highest ``out_degree - 0.3 * in_degree`` in ``T``.

        Returns 0 when ``T`` is None or empty. Ties resolve to the lowest id.
        """
        if T is None or len(T) == 0:
            return 0

        T_np = T.cpu().numpy() if hasattr(T, 'cpu') else T

        # Foundation: high out-degree (prepares others), low in-degree (needs little prep)
        out_degrees = T_np.sum(1)
        in_degrees = T_np.sum(0)

        foundation_score = out_degrees - 0.3 * in_degrees
        return int(np.argmax(foundation_score))

    def _build_order(self, epoch, groups=None):
        """Build a batch order that never draws more batches from a cluster than it holds.

        Order of precedence:
          * ``cfg.order_mode == "single_cluster"``: every batch uses ``cfg.single_cluster_id``.
          * ``cfg.order_mode == "custom"`` with a non-empty ``cfg.custom_order``
            (comma-separated ids): that sequence is repeated.
          * otherwise each cluster may supply ``ceil(len / batch_size)`` batches.
            Immediate repeats are avoided when possible. Clusters are chosen
            uniformly at random for ``epoch < 3`` and by UCB afterwards.

        Args:
            epoch: current epoch (controls exploration vs. UCB).
            groups: mapping cluster id -> sample indices. Defaults to ``self.groups``.

        Returns:
            List of cluster ids, one per batch.

        Raises:
            AttributeError: if ``groups`` is not given and ``self.groups`` is unset.
        """
        if groups is None:
            groups = getattr(self, "groups", None)
            if groups is None:
                raise AttributeError(
                    "OrderScheduler._build_order requires `groups` "
                    "(pass it explicitly or set `self.groups`)"
                )

        n_samples = sum(len(v) for v in groups.values())
        n_batches = math.ceil(n_samples / self.cfg.batch_size)

        # FORCE SINGLE CLUSTER FOR TESTING
        if hasattr(self.cfg, 'single_cluster_id') and self.cfg.order_mode == "single_cluster":
            cluster_id = self.cfg.single_cluster_id
            logger.info(f"Student forcing single cluster {cluster_id} for all {n_batches} batches")
            return [cluster_id] * n_batches

        # Custom order support
        if self.cfg.order_mode == "custom" and getattr(self.cfg, 'custom_order', None):
            # Skip empty entries so e.g. "0,1," does not crash int('').
            order_sequence = [int(x) for x in self.cfg.custom_order.split(',') if x.strip()]
            if order_sequence:
                order = [order_sequence[i % len(order_sequence)] for i in range(n_batches)]
                logger.info(f"Using custom order: {order_sequence} repeated for {n_batches} batches")
                return order
            logger.warning("custom_order contains no cluster ids; falling back to UCB ordering")

        # TRUE capacity per cluster (not oversampled)
        cluster_capacity = {k: math.ceil(len(indices) / self.cfg.batch_size)
                            for k, indices in groups.items()}
        n_batches = min(n_batches, sum(cluster_capacity.values()))

        order = []
        cluster_usage = {k: 0 for k in groups}
        last_cluster = None

        for i in range(n_batches):
            available = [k for k, cap in cluster_capacity.items() if cluster_usage[k] < cap]
            if not available:
                logger.warning(f"All clusters exhausted at step {i}/{n_batches}")
                break

            # Prevent immediate repeats when possible
            if last_cluster is not None and len(available) > 1:
                non_repeat = [k for k in available if k != last_cluster]
                if non_repeat:
                    available = non_repeat

            if epoch < 3:
                # Pure exploration phase
                next_cluster = np.random.choice(available)
            else:
                ucb_scores = []
                for k in available:
                    if self.counts[k] == 0:
                        score = float('inf')
                    else:
                        mean_reward = np.mean(self.rewards[k]) if self.rewards[k] else 0
                        exploration = self.cfg.ucb_c * np.sqrt(
                            2 * np.log(max(1, i + 1)) / max(1, self.counts[k])
                        )
                        score = mean_reward + exploration
                    ucb_scores.append((k, score))
                # max() returns the first maximal entry, like a stable descending sort.
                next_cluster = max(ucb_scores, key=lambda item: item[1])[0]

            order.append(next_cluster)
            cluster_usage[next_cluster] += 1
            last_cluster = next_cluster

        used_clusters = set(order)
        if len(used_clusters) < len(groups):
            logger.warning(f"Order only uses {len(used_clusters)}/{len(groups)} clusters")

        logger.info(f"Built order of {len(order)} batches, cluster usage: {cluster_usage}")
        return order

    def update_reward(self, cluster_id, reward_dict):
        """Record a reward for ``cluster_id``.

        ``reward_dict`` is either a dict, scored as
        ``0.7 * delta_robust + 0.3 * steps_saved_ratio`` (missing keys count as
        0), or a plain scalar reward. Also counts the transition from the
        previously rewarded cluster.
        """
        if isinstance(reward_dict, dict):
            delta_robust = reward_dict.get('delta_robust', 0)
            steps_saved_ratio = reward_dict.get('steps_saved_ratio', 0)
            combined_reward = 0.7 * delta_robust + 0.3 * steps_saved_ratio
        else:
            # Fallback for old callers passing a scalar
            combined_reward = reward_dict

        self.rewards[cluster_id].append(combined_reward)
        self.counts[cluster_id] += 1
        self.total_pulls += 1

        if self.last_cluster is not None:
            self.transition_counts[self.last_cluster, cluster_id] += 1
        self.last_cluster = cluster_id

    def update_transition_reward(self, from_cluster, to_cluster, robust_acc_delta):
        """Append ``robust_acc_delta`` to the reward history of the ``(from, to)`` transition.

        If a ``transition_scheduler`` attribute has been attached, its history
        is updated too.
        """
        if not hasattr(self, 'transition_rewards'):
            self.transition_rewards = defaultdict(list)

        self.transition_rewards[(from_cluster, to_cluster)].append(robust_acc_delta)

        if hasattr(self, 'transition_scheduler'):
            self.transition_scheduler.transition_rewards[(from_cluster, to_cluster)].append(robust_acc_delta)

    def update_probe_result(self, cluster_id, robust_acc, pgd_calls):
        """Record robust accuracy per 1e5 PGD calls for ``cluster_id`` (ignored if ``pgd_calls <= 0``)."""
        if pgd_calls > 0:
            efficiency = robust_acc / (pgd_calls / 1e5)
            self.probe_results[cluster_id].append(efficiency)

    def select_next_cluster(self, available_clusters=None):
        """Pick the next cluster by UCB after a short minimal-exploration phase.

        While fewer than ``2 * n_clusters`` rewards exist, clusters with fewer
        than 2 rewards are sampled uniformly at random.
        """
        if available_clusters is None:
            available_clusters = list(range(self.n_clusters))

        min_samples = 2
        undersampled = [k for k in available_clusters if self.counts[k] < min_samples]

        if undersampled and self.total_pulls < self.n_clusters * min_samples:
            return np.random.choice(undersampled)

        ucb_scores = []
        for k in available_clusters:
            if self.counts[k] == 0:
                ucb_scores.append(float('inf'))
            else:
                mean_reward = np.mean(self.rewards[k]) if self.rewards[k] else 0
                exploration = self.cfg.ucb_c * np.sqrt(
                    2 * np.log(max(1, self.total_pulls)) / max(1, self.counts[k])
                )
                ucb_scores.append(mean_reward + exploration)

        return available_clusters[int(np.argmax(ucb_scores))]

    def get_order(self, n_batches, mode="ucb", cycle=None):
        """Generate an order of ``n_batches`` cluster ids.

        Modes: ``"cycle"`` / ``"reverse_cycle"`` repeat ``cycle`` (or its
        reverse); ``"random"`` draws uniformly; ``"uniform"`` is round-robin;
        ``"single"`` repeats ``cfg.single_cluster_id`` when it is a valid id.
        Anything else (default ``"ucb"``) calls :meth:`select_next_cluster`
        per batch.
        """
        if mode == "cycle" and cycle:
            return [cycle[i % len(cycle)] for i in range(n_batches)]
        if mode == "reverse_cycle" and cycle:
            rev_cycle = list(reversed(cycle))
            return [rev_cycle[i % len(rev_cycle)] for i in range(n_batches)]
        if mode == "random":
            return [np.random.randint(self.n_clusters) for _ in range(n_batches)]
        if mode == "uniform":
            return [i % self.n_clusters for i in range(n_batches)]
        if mode == "single" and self.cfg.single_cluster_id < self.n_clusters:
            return [self.cfg.single_cluster_id] * n_batches
        # UCB selection (default)
        return [self.select_next_cluster() for _ in range(n_batches)]

class TransitionScheduler:
    """Schedule batches based on a cluster-to-cluster transition matrix.

    Observed transitions are counted in ``counts``. :meth:`compute_T` turns
    them into a (optionally reward-weighted) row-stochastic matrix, and
    :meth:`get_next_cluster` samples from ``self.T``. Recording can be limited
    via ``cfg.t_matrix_mode``:
    ``"throughout"`` records always; ``"late"`` records from
    ``cfg.t_matrix_start_epoch``; ``"converged"`` records only in the last
    ``cfg.t_matrix_convergence_window`` epochs.
    """

    def __init__(self, n_clusters, cfg):
        self.n_clusters = n_clusters
        self.cfg = cfg
        # Truly uniform start: every target has probability 1/n. (A scaled
        # identity would pin get_next_cluster to its starting cluster forever.)
        self.T = np.full((n_clusters, n_clusters), 1.0 / max(n_clusters, 1))
        self.counts = np.zeros((n_clusters, n_clusters))
        self.current = 0
        self.transition_rewards = defaultdict(list)

        # Phase-restricted modes start disabled; update_transition /
        # set_recording_mode switch recording on when the phase begins.
        self.recording_enabled = cfg.t_matrix_mode not in ("late", "converged")

    def update_transition(self, from_cluster, to_cluster, epoch=None):
        """Count the ``from_cluster -> to_cluster`` transition if recording is enabled.

        In ``"late"`` mode recording turns on once ``epoch >= cfg.t_matrix_start_epoch``.
        In ``"converged"`` mode it turns on once at most
        ``cfg.t_matrix_convergence_window`` epochs remain. This method never
        turns recording off.
        """
        if self.cfg.t_matrix_mode == "late":
            if epoch is not None and epoch >= self.cfg.t_matrix_start_epoch:
                self.recording_enabled = True

        if self.cfg.t_matrix_mode == "converged":
            # Only record in final epochs
            if epoch is not None:
                epochs_remaining = self.cfg.epochs - epoch
                if epochs_remaining <= self.cfg.t_matrix_convergence_window:
                    self.recording_enabled = True

        if self.recording_enabled:
            self.counts[from_cluster, to_cluster] += 1

            # Log periodically for verification
            total_transitions = self.counts.sum()
            if total_transitions % 100 == 0:
                logger.debug(f"T-matrix mode '{self.cfg.t_matrix_mode}': "
                             f"Recorded {int(total_transitions)} transitions "
                             f"(epoch {epoch})")

    def set_recording_mode(self, mode, current_epoch, cfg):
        """Set ``recording_enabled`` for ``current_epoch`` according to ``mode``.

        Unlike :meth:`update_transition`, this can also disable recording.
        Unknown modes leave the flag unchanged.
        """
        prev_enabled = self.recording_enabled

        if mode == "throughout":
            self.recording_enabled = True
        elif mode == "late":
            self.recording_enabled = (current_epoch >= cfg.t_matrix_start_epoch)
            if current_epoch == cfg.t_matrix_start_epoch and not prev_enabled:
                logger.info(f"T-matrix recording ENABLED at epoch {current_epoch} (late mode)")
        elif mode == "converged":
            epochs_remaining = cfg.epochs - current_epoch
            self.recording_enabled = (epochs_remaining <= cfg.t_matrix_convergence_window)
            if epochs_remaining == cfg.t_matrix_convergence_window and not prev_enabled:
                logger.info(f"T-matrix recording ENABLED - final {cfg.t_matrix_convergence_window} epochs (converged mode)")

    def compute_T(self, normalize=True, use_rewards=True):
        """Return the transition matrix, optionally reward-weighted and row-normalized.

        With ``use_rewards`` and any recorded transition rewards, each
        ``counts[i, j]`` is scaled by ``sigmoid(10 * mean(last 20 rewards))``
        when rewards exist for ``(i, j)``, and by 0.1 otherwise. Otherwise the
        raw counts are used. With ``normalize``, rows are divided by their
        sums; all-zero rows stay zero.
        """
        rewards = getattr(self, 'transition_rewards', None)
        if use_rewards and rewards:
            T = np.zeros((self.n_clusters, self.n_clusters))
            for i in range(self.n_clusters):
                for j in range(self.n_clusters):
                    base_count = self.counts[i, j]
                    # .get() avoids inserting empty lists into the defaultdict.
                    history = rewards.get((i, j))
                    if history:
                        avg_improvement = np.mean(history[-20:])
                        # Sigmoid weight in (0, 1): above 0.5 for a positive mean
                        # improvement, below 0.5 for a negative one.
                        weight = 1 / (1 + np.exp(-avg_improvement * 10))
                        T[i, j] = base_count * weight
                    else:
                        # Transitions without reward data are strongly down-weighted.
                        T[i, j] = base_count * 0.1
        else:
            # Fallback to pure count-based
            T = self.counts.copy()

        if normalize:
            row_sums = T.sum(axis=1, keepdims=True)
            row_sums[row_sums == 0] = 1  # Avoid division by zero
            T = T / row_sums

        return T

    def get_next_cluster(self, mode='natural'):
        """Sample the next cluster from ``self.T`` and make it the current one.

        ``'natural'`` samples from row ``current`` and ``'reverse'`` from column
        ``current``. Any other mode returns a uniform random cluster without
        updating ``current``. An all-zero row/column also falls back to uniform.
        """
        if mode == 'natural':
            probs = self.T[self.current]
        elif mode == 'reverse':
            probs = self.T.T[self.current]
        else:  # random
            return np.random.randint(self.n_clusters)

        if probs.sum() == 0:
            next_cluster = np.random.randint(self.n_clusters)
        else:
            probs = probs / probs.sum()
            next_cluster = np.random.choice(self.n_clusters, p=probs)

        self.current = next_cluster
        return next_cluster

# ========================= LOAT Trainer =========================
class LOATTrainer:
    """Main trainer for Latent-Order Adversarial Training"""

    def __init__(self, cfg: Config):
        """Build the model, feature extractors, clusterer, schedulers and metric stores.

        Args:
            cfg: experiment configuration. Fields read here: ``device``,
                ``n_clusters``, ``ae_latent_dim``, ``ema_enabled``,
                ``use_hdbscan`` (optional, default False), ``mv_mode`` (HDBSCAN
                path only), ``lr``, ``momentum``, ``weight_decay``, ``ae_lr``.
        """
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.current_epoch = 0
        # Per-channel CIFAR-10 normalization constants, shaped for NCHW broadcasting.
        self.mean = torch.tensor(CIFAR10_MEAN).view(1, 3, 1, 1).to(self.device)
        self.std = torch.tensor(CIFAR10_STD).view(1, 3, 1, 1).to(self.device)

        self.transition_scheduler = TransitionScheduler(cfg.n_clusters, cfg)

        # Models first
        self.model = self._build_model()
        self.autoencoder = DenoisingAutoencoder(cfg.ae_latent_dim).to(self.device)

        self.batches_per_epoch = None

        # Extractors
        self.geom_profiler = GeometryProfiler(cfg, self.device)
        self.stats_profiler = StatsProfiler(self.device, cfg)

        # Short aliases for the extractors/autoencoder
        self.geom = self.geom_profiler
        self.stats = self.stats_profiler
        self.ae = self.autoencoder
        self.difficulty_profiles = {}  # Initialized even if no recipe is used
        self.uncertain_indices = []

        self.is_student_mode = False  # Will be set by recipe loading
        self.student_difficulty_scale = 1.0
        self.student_cluster_performance = defaultdict(list)
        self._fallback_cluster_idx = 0
        self.discovery_feature_mode = 'adaptive'
        self.feature_weights = {}
        self.discovered_best_order = []

        # EMA model: a frozen copy of the model
        self.model_ema = None
        if cfg.ema_enabled:
            self.model_ema = copy.deepcopy(self.model)
            self.model_ema.eval()
            for p in self.model_ema.parameters():
                p.requires_grad = False

        # Clustering
        if HDBSCAN_AVAILABLE and getattr(cfg, 'use_hdbscan', False):
            stats_clusterer = HDBSCANClusterer(min_cluster_size=15)
            geom_clusterer = HDBSCANClusterer(min_cluster_size=15)
            self.clusterer = MultiViewClusterer(
                cfg,
                stats_clusterer=stats_clusterer,
                geom_clusterer=geom_clusterer,
                mode=cfg.mv_mode
            )
        else:
            self.clusterer = AdaptiveMultiViewClusterer(cfg)

        self.scheduler = OrderScheduler(cfg.n_clusters, cfg)

        # Groups
        self.groups = None
        self.batch_labels = None

        # Optimizers
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=cfg.lr,
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay
        )
        self.ae_optimizer = torch.optim.Adam(
            self.autoencoder.parameters(),
            lr=cfg.ae_lr
        )

        # LR scheduler (must be built after self.optimizer)
        self.lr_scheduler = self._build_lr_scheduler()

        # Metrics
        self.metrics = defaultdict(list)
        self.pgd_calls_train = 0
        self.pgd_calls_epoch = 0
        self.pgd_calls_eval = 0
        self.global_step = 0
        self.transition_qualities = defaultdict(list)
        self.prev_loss_for_transition = float('inf')
        self.prev_robust_for_transition = 0

        self.transition_rewards = defaultdict(list)  # Track transition success
        self.prev_robust = 0
        self.prev_loss = float('inf')

        # Block tracking
        self.current_block = []
        self.block_start_robust = 0
        self.block_start_pgd = 0

        self.feature_mode = None  # "ae", "model", or "combined"
        self.feature_dim = None   # Actual dimension used

        # HITL reporting
        self.hitl_reports = []

        # Cluster stability tracking
        self.prev_labels = None

        self.detailed_metrics = {
            # Per-epoch metrics
            'epoch': [],
            'train_loss': [],
            'train_clean_acc': [],
            'train_robust_acc': [],
            'val_clean_acc': [],
            'val_robust_acc': [],
            'test_clean_acc': [],
            'test_robust_acc': [],

            # Efficiency metrics
            'pgd_calls_cumulative': [],
            'pgd_steps_saved': [],  # from early stopping
            'efficiency_score': [],
            'time_per_epoch': [],
            'cumulative_time': [],

            # Clustering metrics
            'n_clusters_discovered': [],
            'cluster_sizes': [],
            'cluster_stability_ari': [],
            'silhouette_stats': [],
            'silhouette_geom': [],
            'multiview_agreement': [],

            # Group-wise performance
            'group_robust_acc': defaultdict(list),
            'group_loss': defaultdict(list),
            'group_sample_efficiency': defaultdict(list),

            # Adversarial metrics
            'avg_perturbation_norm': [],
            'successful_attack_rate': [],
            'margin_loss': [],
            'kl_divergence': [],

            # Scheduling metrics
            'group_selection_counts': defaultdict(int),
            'group_rewards': defaultdict(list),
            'ucb_scores': defaultdict(list),
            'transition_matrix': [],

            # Gradient statistics
            'grad_norm_mean': [],
            'grad_norm_std': [],
            'weight_norm': [],

            # Learning dynamics
            'learning_rate': [],
            'beta_value': [],
            'ema_distance': [],
        }

        # Per-batch tracking (rolling window of the last 100 batches)
        self.batch_metrics = {
            'losses': deque(maxlen=100),
            'clean_acc': deque(maxlen=100),
            'robust_acc': deque(maxlen=100),
            'pgd_steps': deque(maxlen=100),
        }

        self.time_series_log = []
        self.training_start_wall_time = None

        self.epoch_start_time = None
        self.training_start_time = None

        # Cluster difficulty tracking
        self.cluster_difficulties = {}  # cluster_id -> difficulty (0-1)
        self.cluster_robust_history = defaultdict(list)
        self.cluster_pgd_efficiency = defaultdict(list)
        self.pgd_calls_per_cluster = defaultdict(int)
        self.robust_acc_per_cluster = defaultdict(list)

    def refine_recipe_with_performance(self):
        """Refine the recipe using performance observed during training.

        1. For each cluster id in ``range(cfg.n_clusters)`` with recorded robust
           accuracies, summarise the last 20: mean (``avg_robust``) and
           ``(last - first) / len`` (``improvement_rate``).
        2. For clusters present in ``self.difficulty_profiles``, blend the
           observed difficulty ``1 - avg_robust`` with the stored
           ``overall_difficulty`` (30% stored, 70% observed). The profiles are
           mutated in place.
        3. For transitions with at least 10 recorded qualities, average the
           last 10 ``robust_gain`` / ``steps_used`` values. Keep transitions
           with mean gain > 0.02, ranked by gain per 10 steps.

        Returns:
            dict with ``cluster_performance``, ``refined_difficulties`` (the
            same object as ``self.difficulty_profiles``) and
            ``recommended_paths`` (at most 10, highest priority first).
        """
        logger.info("Refining recipe based on observed performance...")

        # 1. Identify clusters that led to good outcomes
        cluster_performance = {}
        for cid in range(self.cfg.n_clusters):
            # .get() avoids inserting empty lists into the defaultdict.
            history = self.robust_acc_per_cluster.get(cid)
            if not history:
                continue
            recent = history[-20:]
            cluster_performance[cid] = {
                'avg_robust': np.mean(recent),
                'improvement_rate': (recent[-1] - recent[0]) / len(recent) if len(recent) > 1 else 0,
            }

        # 2. Update difficulty profiles based on actual performance
        for cid, perf in cluster_performance.items():
            if cid not in self.difficulty_profiles:
                continue
            profile = self.difficulty_profiles[cid]
            observed_difficulty = 1.0 - perf['avg_robust']
            expected_difficulty = profile.get('overall_difficulty', 0.5)

            # Blend observed with expected
            profile['overall_difficulty'] = 0.3 * expected_difficulty + 0.7 * observed_difficulty
            profile['observed_robust'] = perf['avg_robust']
            profile['improvement_rate'] = perf['improvement_rate']

        # 3. Create refined transition recommendations
        recommended_paths = []
        for (from_c, to_c), qualities in self.transition_qualities.items():
            if len(qualities) < 10:
                continue
            recent = qualities[-10:]
            avg_gain = np.mean([q['robust_gain'] for q in recent])
            avg_steps = np.mean([q['steps_used'] for q in recent])

            if avg_gain > 0.02:  # Significant improvement
                recommended_paths.append({
                    'from': from_c,
                    'to': to_c,
                    'gain': avg_gain,
                    'steps': avg_steps,
                    # Gain per 10 steps. The floor prevents an infinite priority
                    # (not valid JSON) when no steps were used.
                    'priority': avg_gain / max(avg_steps / 10, 1e-8),
                })

        recommended_paths.sort(key=lambda path: path['priority'], reverse=True)

        return {
            'cluster_performance': cluster_performance,
            'refined_difficulties': self.difficulty_profiles,
            'recommended_paths': recommended_paths[:10]  # Top 10 paths
        }


    def train_simclr_encoder(self, epochs=50, *, max_epochs=5, max_batches=101):
        """Pre-train ``self.simclr`` with a rotation-prediction pretext task.

        No contrastive loss is used here. Each batch is rotated by 0/90/180/270
        degrees, and a single linear head on the encoder's second output is
        trained jointly with the encoder to predict the rotation.

        Args:
            epochs: requested epochs; capped at ``max_epochs``.
            max_epochs: hard cap on epochs (default keeps the quick-test limit).
            max_batches: maximum batches per epoch (default 101, as before).

        Returns:
            The trained ``SimCLREncoder``, or ``None`` if the device sanity check failed.
        """
        logger.info("Pre-training SimCLR encoder...")

        # Skip if CUDA issues detected
        try:
            test_tensor = torch.randn(2, 3, 32, 32).to(self.device)
            test_output = F.normalize(test_tensor, dim=1)
            del test_tensor, test_output
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            logger.warning("CUDA issues detected, skipping SimCLR pre-training", exc_info=True)
            return None

        self.simclr = SimCLREncoder(base_model='resnet18').to(self.device)
        optimizer = torch.optim.Adam(self.simclr.parameters(), lr=0.001)
        # Built lazily on the first batch (its input width is only known then)
        # and registered with the optimizer so it is actually trained.
        rotation_head = None

        # Use the standard CIFAR-10 training augmentation for simplicity
        transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        simclr_dataset = torchvision.datasets.CIFAR10(
            root=self.cfg.data_root, train=True, transform=transform
        )
        simclr_loader = DataLoader(simclr_dataset, batch_size=128, shuffle=True, num_workers=0)

        logger.info(f"SimCLR training on {len(simclr_dataset)} samples")

        n_epochs = min(epochs, max_epochs)
        for epoch in range(n_epochs):
            self.simclr.train()
            total_loss = 0
            n_batches = 0

            for batch_idx, (x, _) in enumerate(simclr_loader):
                if batch_idx >= max_batches:
                    break

                x = x.to(self.device)

                # Self-supervised task: predict the rotation applied to each image
                x_all = torch.cat([torch.rot90(x, k, [2, 3]) for k in range(4)], dim=0)
                labels = torch.arange(4).repeat_interleave(x.size(0)).to(self.device)

                _, z = self.simclr(x_all)

                if rotation_head is None:
                    rotation_head = nn.Linear(z.size(1), 4).to(self.device)
                    optimizer.add_param_group({'params': rotation_head.parameters()})

                loss = F.cross_entropy(rotation_head(z), labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_batches += 1

                if batch_idx % 20 == 0:
                    logger.info(f"SimCLR Epoch {epoch+1}/{n_epochs}, "
                                f"Batch {batch_idx}, Loss: {loss.item():.4f}")

            # Average over batches actually processed (the break above skips one).
            logger.info(f"SimCLR Epoch {epoch+1} complete, Avg Loss: {total_loss/max(n_batches, 1):.4f}")

        return self.simclr

    def train_teacher_two_phase(self, train_loader, val_loader, test_loader):
        """Train the teacher, then discover clusters on it and save a recipe.

        Phases:
            1. (if ``cfg.use_simclr``) contrastive pre-training of ``self.simclr``.
            2. Regular training via ``self.train``. Before it starts, sets
               ``cfg.discovery_interval = 10`` and, in ``"late"`` T-matrix mode,
               ``cfg.t_matrix_start_epoch = 15``.
            3. ``self.discover_and_profile()`` on the trained model.
            4. ``self.save_comprehensive_recipe(...)``.

        Returns:
            The metrics returned by ``self.train``. If ``'val_robust_acc'`` is
            missing, it and ``'val_clean_acc'`` are filled from
            ``self.evaluate(val_loader)``.
        """
        # Initialize loaders and datasets first
        self.train_loader = train_loader
        self.train_set = train_loader.dataset
        self.val_loader = val_loader

        # Phase 1: Pre-train SimCLR if enabled
        if self.cfg.use_simclr:
            self._pretrain_simclr_contrastive()

        # Phase 2: Train teacher model normally
        logger.info("Phase 2: Training teacher model...")
        # NOTE(review): an earlier comment said this "enables discovery at epoch 15",
        # but the value is 10 -- confirm how discovery_interval is interpreted.
        self.cfg.discovery_interval = 10
        if self.cfg.t_matrix_mode == "late":
            self.cfg.t_matrix_start_epoch = 15  # Record transitions from epoch 15 onwards
        # For no discovery use: self.cfg.discovery_interval = 999

        metrics = self.train(train_loader, val_loader, test_loader)

        # Phase 3: Comprehensive discovery on trained model
        logger.info("Phase 3: Discovering clusters on trained model...")
        groups, labels, T, difficulty_profiles = self.discover_and_profile()

        # Phase 4: Save comprehensive recipe
        self.save_comprehensive_recipe(groups, labels, T, difficulty_profiles)

        # Ensure metrics has the expected structure
        if 'val_robust_acc' not in metrics:
            val_metrics = self.evaluate(val_loader)
            metrics['val_robust_acc'] = [val_metrics['robust']]
            metrics['val_clean_acc'] = [val_metrics['clean']]

        return metrics

    def _pretrain_simclr_contrastive(self, epochs=50, temperature=0.07):
        """Contrastively pre-train a fresh ``self.simclr`` on ``self.train_set`` with NT-Xent."""
        logger.info("Pre-training SimCLR encoder for feature stabilization...")
        self.simclr = SimCLREncoder(base_model='resnet18').to(self.device)

        simclr_optimizer = torch.optim.Adam(self.simclr.parameters(), lr=0.001)
        simclr_lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            simclr_optimizer, T_max=epochs
        )

        simclr_loader = DataLoader(
            self.train_set,
            batch_size=min(512, self.cfg.batch_size * 4),  # Larger batch for contrastive
            shuffle=True,
            num_workers=0,
            drop_last=True  # Important for batch norm
        )
        # Guard against an empty loader (dataset smaller than one batch with drop_last).
        n_loader_batches = max(1, len(simclr_loader))

        for epoch in range(epochs):
            epoch_loss = 0
            for x, _ in simclr_loader:
                x = x.to(self.device)

                # Two augmented views of the same batch
                x_aug1 = x + torch.randn_like(x) * 0.1
                x_aug1 = transforms.functional.adjust_brightness(x_aug1, 0.8 + torch.rand(1).item() * 0.4)

                x_aug2 = torch.flip(x, dims=[3])
                x_aug2 = x_aug2 + torch.randn_like(x_aug2) * 0.1

                _, z1 = self.simclr(x_aug1)
                _, z2 = self.simclr(x_aug2)

                loss = self._nt_xent_loss(z1, z2, temperature)

                simclr_optimizer.zero_grad()
                loss.backward()
                simclr_optimizer.step()
                epoch_loss += loss.item()

            simclr_lr_sched.step()
            if (epoch + 1) % 10 == 0:
                logger.info(f"SimCLR epoch {epoch+1}/{epochs}, Loss: {epoch_loss/n_loader_batches:.4f}")

        logger.info("SimCLR pre-training complete")

    @staticmethod
    def _nt_xent_loss(z1, z2, temperature=0.07):
        """NT-Xent loss for two views; row i of ``z1`` and row i of ``z2`` are positives.

        Assumes ``z1``/``z2`` are ``(B, D)``. Every other row of the
        concatenated ``(2B, D)`` batch is a negative. Each sample's similarity
        with itself is masked out so it cannot dominate the softmax denominator.
        """
        batch_size = z1.size(0)
        z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
        # (2B, 2B) cosine similarities via one matmul instead of a (2B, 2B, D) broadcast.
        sim = (z @ z.t()) / temperature
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=sim.device)
        sim = sim.masked_fill(self_mask, float('-inf'))

        labels = torch.cat([
            torch.arange(batch_size) + batch_size,
            torch.arange(batch_size)
        ], dim=0).to(sim.device)
        return F.cross_entropy(sim, labels)

    def discover_and_profile(self):
        """Discover clusters on the trained model and profile their difficulty.

        Temporarily sets the clusterer's ``warmup_epochs`` to 0 (when the
        attribute exists) and always restores it afterwards. The steps are:

        1. run :meth:`comprehensive_discovery_pass`;
        2. make sure ``feature_weights``, ``discovered_best_order`` and
           ``transition_scheduler`` exist for the recipe writer;
        3. refit the clusterer if its scalers are missing;
        4. choose the transition matrix;
        5. profile every cluster.

        For step 4, the training-time scheduler's ``compute_T()`` is used when
        a scheduler already existed. Otherwise discovery's ``T_raw`` is used,
        or the identity if ``T_raw`` is None.

        Returns:
            ``(groups, labels, T, difficulty_profiles)`` where ``T`` is a
            float32 tensor on ``self.device``.

        Raises:
            RuntimeError: if the clusterer still lacks fitted scalers after a
                refit attempt.
        """
        original_warmup = None
        if hasattr(self.clusterer, 'warmup_epochs'):
            original_warmup = self.clusterer.warmup_epochs
            self.clusterer.warmup_epochs = 0

        try:
            # Run discovery with multiple feature sources
            groups, labels, T_raw, Xs, Xg = self.comprehensive_discovery_pass()

            # Ensure feature_weights exists for recipe saving
            if not hasattr(self, 'feature_weights'):
                self.feature_weights = {
                    'selected_features': ['stats', 'geom'],
                    'feature_combination_mode': 'multi_view',
                    'best_ordering': 'cyclical',
                    'best_order_sequence': list(range(len(groups))) * 100
                }
            if not hasattr(self, 'discovered_best_order'):
                self.discovered_best_order = list(range(len(groups))) * 100

            # Remember this BEFORE creating a fallback scheduler. Re-checking
            # hasattr after the creation below would always be True, so the
            # T_raw fallback could never be taken.
            scheduler_was_missing = not hasattr(self, 'transition_scheduler')
            if scheduler_was_missing:
                self.transition_scheduler = TransitionScheduler(self.cfg.n_clusters, self.cfg)
                logger.warning("TransitionScheduler was not initialized, creating now")

            # VERIFY clusterer was fitted properly
            if self.clusterer.stats_mu is None or self.clusterer.geom_mu is None:
                logger.error("CRITICAL: Clusterer not fitted properly after discovery!")
                logger.info("Attempting to refit clusterer...")
                # comprehensive_discovery_pass walks train_set un-shuffled in
                # cfg.batch_size chunks, so the batch -> dataset-index mapping
                # can be rebuilt deterministically instead of passing [].
                n_total = len(self.train_set)
                bs = self.cfg.batch_size
                batch_ids_list = [list(range(start, min(start + bs, n_total)))
                                  for start in range(0, n_total, bs)]
                self.clusterer.fit(Xs, Xg, batch_ids_list, epoch=999)

                if self.clusterer.stats_mu is None or self.clusterer.geom_mu is None:
                    raise RuntimeError("Failed to fit clusterer - cannot proceed with recipe creation")

            logger.info(f"Clusterer state confirmed: stats_mu shape={self.clusterer.stats_mu.shape if self.clusterer.stats_mu is not None else None}")

            if scheduler_was_missing:
                # No transitions were recorded during training; use discovery's estimate.
                T = T_raw if T_raw is not None else torch.eye(len(groups), device=self.device)
                T = torch.as_tensor(T, dtype=torch.float32, device=self.device)
            else:
                # Use T built during training
                T = torch.as_tensor(self.transition_scheduler.compute_T(),
                                    dtype=torch.float32, device=self.device)

                logger.info(f"Using transition matrix from {self.cfg.t_matrix_mode} mode")
                logger.info(f"Total transitions recorded: {int(self.transition_scheduler.counts.sum())}")

                # Mean row entropy divided by log(K): 0 means deterministic
                # transitions and 1 means uniform rows. Dividing the total
                # entropy by K*K (as before) gives log(K)/K for a uniform T,
                # which is always < 0.8, so noise was reported as "structured".
                # NOTE(review): assumes compute_T() rows are normalised — confirm.
                k = T.shape[1]
                row_entropy = -(T * torch.log(T + 1e-8)).sum(dim=1)
                T_entropy = float(row_entropy.mean()) / math.log(k) if k > 1 else 0.0
                logger.info(f"T matrix entropy: {T_entropy:.3f} ({'structured' if T_entropy < 0.8 else 'random-like'})")

            # Profile each cluster with various attack parameters
            difficulty_profiles = self.profile_clusters_comprehensively(groups)
        finally:
            # Restore the warm-up gate even if discovery or profiling raised.
            if original_warmup is not None:
                self.clusterer.warmup_epochs = original_warmup

        return groups, labels, T, difficulty_profiles

    def comprehensive_discovery_pass(self):
        """Cluster training batches using AE, backbone and (optional) SimCLR features.

        Pass 1 concatenates per-sample features over the whole (un-shuffled)
        training set. It then fits ``self.geom``'s codebook and prototypes on
        them and records ``discovery_feature_mode`` / ``discovery_feature_dim``.

        Pass 2 computes one geometry vector and one statistics vector per
        batch. It fits ``self.clusterer`` on those vectors and groups the
        dataset indices by the final label of their batch.

        Returns:
            ``(groups, labels, T, Xs, Xg)``:

            * ``groups`` maps cluster id -> list of dataset indices;
            * ``labels`` is ``clusterer.final_labels`` (one label per batch);
            * ``T`` is ``clusterer.get_transition()``;
            * ``Xs`` / ``Xg`` are the stacked per-batch stats / geometry
              features.

        Raises:
            RuntimeError: if ``self.train_set`` yields no batches.
        """
        device = self.device

        # Create loader for discovery (order must be stable across both passes)
        disc_loader = DataLoader(
            self.train_set,
            batch_size=self.cfg.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0
        )

        self.model.eval()
        self.autoencoder.eval()

        # First pass: collect per-source features for every sample
        per_source = None
        with torch.no_grad():
            for x, _ in disc_loader:
                feats = self._extract_discovery_features(x.to(device))
                if per_source is None:
                    per_source = [[] for _ in feats]
                for bucket, feat in zip(per_source, feats):
                    bucket.append(feat.cpu())

        if per_source is None:
            raise RuntimeError("comprehensive_discovery_pass requires a non-empty train_set")

        combined_features = [torch.cat(bucket, dim=0) for bucket in per_source]

        # Store feature configuration
        if len(combined_features) == 3:
            self.discovery_feature_mode = 'ae_model_simclr'
        elif len(combined_features) == 2:
            self.discovery_feature_mode = 'ae_model'
        else:
            self.discovery_feature_mode = 'ae_only'

        Z_combined = torch.cat(combined_features, dim=1) if len(combined_features) > 1 else combined_features[0]
        self.discovery_feature_dim = Z_combined.shape[1]
        self.geom_profiler.expected_dim = Z_combined.shape[1]
        logger.info(f"Discovery using {self.discovery_feature_mode} with dim={self.discovery_feature_dim}")

        # Fit geometry components
        self.geom.fit_codebook(Z_combined)
        self.geom.fit_prototypes(Z_combined)

        # Second pass: compute batch-level statistics and geometry
        Xs_list, Xg_list, batch_ids_list = [], [], []
        offset = 0

        for x, y in disc_loader:
            x, y = x.to(device), y.to(device)

            # Recreate combined features for this batch
            with torch.no_grad():
                feats = self._extract_discovery_features(x)
                z = torch.cat(feats, dim=1) if len(feats) > 1 else feats[0]

            # Geometry / statistics run outside no_grad, as before (they may need gradients)
            Xg_list.append(self.geom.batch_geometry(z, x_batch=x, model=self.model))
            Xs_list.append(self.stats.extract_batch_stats(self.model, x, y, self.cfg.epsilon))

            # Running offset: b_idx * x.size(0) was wrong for a short final batch.
            n = x.size(0)
            batch_ids_list.append(list(range(offset, offset + n)))
            offset += n

        Xs = torch.stack(Xs_list, dim=0)
        Xg = torch.stack(Xg_list, dim=0)

        # Multi-view clustering - THIS IS CRITICAL
        self.clusterer.fit(Xs, Xg, batch_ids_list, epoch=999)
        labels = self.clusterer.final_labels
        T = self.clusterer.get_transition()

        # Build groups: every index of a batch goes to that batch's cluster
        Kf = int(labels.max().item()) + 1
        groups = {k: [] for k in range(Kf)}
        for b, lab in enumerate(labels.tolist()):
            groups[lab].extend(batch_ids_list[b])

        return groups, labels, T, Xs, Xg

    def _extract_discovery_features(self, x):
        """Return the per-source feature tensors for ``x`` in a fixed order.

        The order is: the autoencoder latent; the backbone's flattened
        ``avgpool(layer4(...))`` output; and SimCLR features when a ``simclr``
        attribute exists. Callers wrap this in ``torch.no_grad()``.
        """
        feats = [self.autoencoder.encode(x)]

        h = self.model.relu(self.model.bn1(self.model.conv1(
            Attacks.normalize(x, self.mean, self.std))))
        for stage in (self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4):
            h = stage(h)
        feats.append(self.model.avgpool(h).view(h.size(0), -1))

        if hasattr(self, 'simclr'):
            feats.append(self.simclr.extract_features(x))
        return feats

    def profile_clusters_comprehensively(self, groups, max_samples=1000, chunk_size=64,
                                         pgd_step_grid=(2, 3, 5, 7, 10, 15, 20, 30)):
        """Measure attack difficulty for every discovered cluster.

        For each cluster, up to ``max_samples`` member indices are drawn
        without replacement from the global NumPy RNG. Each chunk of samples
        is loaded once and reused for every PGD setting.

        Per-sample clean margins (correct logit minus the best other logit)
        are stored in ``profile['margin_distribution']``. For each step count
        in ``pgd_step_grid``, PGD at ``cfg.epsilon`` is run and
        ``profile['pgd_steps_to_fool'][steps]`` stores the mean, over chunks,
        of two quantities:

        * the attack success rate;
        * the L2 norm of the CE-loss input gradient at the adversarial point.

        Args:
            groups: mapping cluster id -> sequence of ``self.train_set`` indices.
            max_samples: per-cluster sample cap.
            chunk_size: mini-batch size used for the attacks.
            pgd_step_grid: PGD step counts to evaluate.

        Returns:
            dict cluster id -> profile dict, each with ``overall_difficulty``.
            ``asr`` / ``margin`` / ``grad_complexity`` keep their placeholder
            defaults.
        """
        difficulty_profiles = {}

        for cluster_id, indices in groups.items():
            logger.info(f"Profiling cluster {cluster_id} ({len(indices)} samples)...")

            profile = {
                'pgd_steps_to_fool': {},
                'margin_distribution': [],
                'gradient_stats': {},
                'loss_landscape': {},
                'attack_transferability': {},
                'asr': 0.5,  # Default values for missing keys
                'margin': 0.5,
                'grad_complexity': 1.0
            }

            if len(indices) == 0:
                # Nothing to measure: averaging empty lists would give NaN.
                # 0.5 is the same neutral difficulty the student fallback uses.
                profile['overall_difficulty'] = 0.5
                difficulty_profiles[cluster_id] = profile
                continue

            # Sample subset for efficiency
            sample_size = min(max_samples, len(indices))
            sampled_indices = np.random.choice(indices, sample_size, replace=False)

            # Load each chunk once. Clean margins do not depend on the PGD step
            # count, so they are computed here a single time.
            chunks = []
            for start in range(0, len(sampled_indices), chunk_size):
                batch_indices = sampled_indices[start:start + chunk_size]
                x = torch.stack([self.train_set[int(idx)][0] for idx in batch_indices])
                y = torch.tensor([self.train_set[int(idx)][1] for idx in batch_indices])
                x, y = x.to(self.device), y.to(self.device)
                chunks.append((x, y))

                with torch.no_grad():
                    logits_clean = self.model(Attacks.normalize(x, self.mean, self.std))
                    # squeeze(1), not squeeze(): a size-1 chunk must stay 1-D
                    # so .tolist() returns a list, not a bare float.
                    correct = logits_clean.gather(1, y.unsqueeze(1)).squeeze(1)
                    others = logits_clean.scatter(1, y.unsqueeze(1), float('-inf'))
                    margins = correct - others.max(dim=1).values
                    profile['margin_distribution'].extend(margins.cpu().tolist())

            # Test with multiple PGD configurations
            for pgd_steps in pgd_step_grid:
                asr_list = []
                grad_norm_list = []

                for x, y in chunks:
                    x_adv, _ = Attacks.pgd(
                        self.model, x, y,
                        self.cfg.epsilon, self.cfg.pgd_step_size, pgd_steps,
                        self.mean, self.std,
                        use_adv_bn=True, early_stop=False
                    )

                    with torch.no_grad():
                        logits_adv = self.model(Attacks.normalize(x_adv, self.mean, self.std))
                        asr_list.append((logits_adv.argmax(1) != y).float().mean().item())

                    # Fresh leaf tensor so autograd.grad works whatever graph
                    # state the attack left x_adv in.
                    x_leaf = x_adv.detach().requires_grad_(True)
                    loss = F.cross_entropy(
                        self.model(Attacks.normalize(x_leaf, self.mean, self.std)), y
                    )
                    grad = torch.autograd.grad(loss, x_leaf)[0]
                    grad_norm_list.append(grad.norm(p=2, dim=(1, 2, 3)).mean().item())

                profile['pgd_steps_to_fool'][pgd_steps] = {
                    'asr': float(np.mean(asr_list)),
                    'grad_norm': float(np.mean(grad_norm_list)),
                }

            # TODO: an epsilon sweep (0.5x-1.5x cfg.epsilon) was planned here but
            # never implemented; the former placeholder loop did nothing.

            # Compute overall difficulty score
            profile['overall_difficulty'] = self._compute_difficulty_score(profile)
            difficulty_profiles[cluster_id] = profile

        return difficulty_profiles

    def _compute_difficulty_score(self, profile):
        """Compute single difficulty score from profile"""

        # Weighted combination of factors
        asr_10 = profile['pgd_steps_to_fool'].get(10, {}).get('asr', 0.5)
        grad_complexity = np.mean([v['grad_norm'] for v in profile['pgd_steps_to_fool'].values()])

        # Higher ASR and lower gradient norms = easier to attack = higher difficulty for training
        difficulty = asr_10 * 0.5 + (1 / (grad_complexity + 1)) * 0.5

        return difficulty

    def _extract_high_reward_paths(self, top_k=5):
        """Extract the highest reward transition paths discovered"""
        if not hasattr(self, 'transition_rewards'):
            return []

        path_scores = []
        for (i, j), rewards in self.transition_rewards.items():
            if len(rewards) >= 3:  # Need enough samples
                avg_reward = np.mean(rewards[-10:])  # Recent average
                path_scores.append({
                    'from': int(i),
                    'to': int(j),
                    'avg_reward': float(avg_reward),
                    'count': len(rewards)
                })

        # Sort by average reward
        path_scores.sort(key=lambda x: x['avg_reward'], reverse=True)
        return path_scores[:top_k]

    def _extract_discovered_paths(self):
        """Extract the highest-value discovered conceptual paths"""
        if not hasattr(self, 'transition_scheduler'):
            return {}

        paths = {}
        T = self.transition_scheduler.compute_T(normalize=False)

        # Find strong paths (high transition value)
        for i in range(T.shape[0]):
            for j in range(T.shape[1]):
                if T[i, j] > T.mean() * 2:  # Significantly above average
                    if i not in paths:
                        paths[i] = []
                    paths[i].append({
                        'to': j,
                        'strength': float(T[i, j]),
                        'avg_improvement': np.mean(self.transition_scheduler.transition_rewards.get((i,j), [0]))
                    })

        return paths

    def _extract_conceptual_paths(self):
        """Extract discovered conceptual prerequisite paths"""
        if not hasattr(self, 'transition_scheduler'):
            return {}

        paths = {}
        T = self.transition_scheduler.compute_T(normalize=False, use_rewards=True)

        # Find strong conceptual connections
        for i in range(T.shape[0]):
            strong_transitions = []
            for j in range(T.shape[1]):
                if T[i, j] > T.mean() * 1.5:  # Above average transitions
                    rewards = self.transition_scheduler.transition_rewards.get((i,j), [])
                    strong_transitions.append({
                        'to': int(j),
                        'strength': float(T[i, j]),
                        'avg_improvement': float(np.mean(rewards[-10:])) if rewards else 0.0
                    })
            if strong_transitions:
                paths[int(i)] = strong_transitions

        return paths

    def _apply_learned_feature_weights(self, x):
        """Apply teacher's learned feature weights to extract features"""
        if not hasattr(self, 'feature_weights') or 'weights' not in self.feature_weights:
            # Fallback to autoencoder
            return self.autoencoder.encode(x)

        weights = self.feature_weights['weights']
        combined_features = []

        with torch.no_grad():
            # Apply each weighted feature
            if weights.get('stats', 0) > 0:
                # Stats need batch-level, handle separately
                pass  # Will be handled at batch level

            if weights.get('geom', 0) > 0:
                z_ae = self.autoencoder.encode(x)
                combined_features.append(z_ae * weights['geom'])

            if weights.get('confidence', 0) > 0:
                logits = self.model(Attacks.normalize(x, self.mean, self.std))
                probs = F.softmax(logits, dim=1)
                conf_feat = probs.max(dim=1)[0].unsqueeze(1)
                combined_features.append(conf_feat * weights['confidence'])

            if weights.get('activations', 0) > 0:
                h = self.model.layer3(self.model.layer2(self.model.layer1(
                    self.model.relu(self.model.bn1(self.model.conv1(
                        Attacks.normalize(x, self.mean, self.std)))))))
                h_pooled = h.mean(dim=[2, 3])
                combined_features.append(h_pooled * weights['activations'])

        if combined_features:
            return torch.cat(combined_features, dim=1)
        else:
            return self.autoencoder.encode(x)  # Fallback


    def save_comprehensive_recipe(self, groups, labels, T, difficulty_profiles):
        """Save the complete teacher recipe and return its path.

        Side effects:
            * casts ``overall_difficulty`` inside ``difficulty_profiles`` to
              ``float`` in place;
            * writes the clusterer state to
              ``<experiment_dir>/clusterer_models.pkl`` via
              ``self.clusterer.save_state``;
            * calls ``self.refine_recipe_with_performance()`` once and
              ``self.evaluate(self.val_loader)`` once;
            * writes ``<experiment_dir>/teacher_recipe_comprehensive.pkl``
              with ``torch.save``.

        Returns:
            Path of the comprehensive recipe file.

        Raises:
            RuntimeError: if the clusterer's scalers are missing.
        """
        import hashlib

        recipe_version = '2.1'

        # Convert numpy scalars to Python floats
        for profile in difficulty_profiles.values():
            if 'overall_difficulty' in profile:
                profile['overall_difficulty'] = float(profile['overall_difficulty'])

        has_simclr = hasattr(self, 'simclr')
        has_feature_weights = hasattr(self, 'feature_weights')
        feature_weights = self.feature_weights if has_feature_weights else {}

        # Get feature configuration
        feature_config = {
            'mode': getattr(self, 'discovery_feature_mode', 'ae_model'),
            'dim': getattr(self, 'discovery_feature_dim', None),
            'ae_dim': self.cfg.ae_latent_dim,
            'model_dim': 512,
            'use_simclr': has_simclr,
            'simclr_dim': 128 if has_simclr else None,
        }

        clusterer = self.clusterer
        clusterer_state_path = self.cfg.experiment_dir / 'clusterer_models.pkl'

        # CRITICAL CHECK: Ensure we have scalers before saving
        if clusterer.stats_mu is None or clusterer.geom_mu is None:
            logger.error("Cannot save recipe - clusterer scalers are missing!")
            logger.error(f"stats_mu: {clusterer.stats_mu is not None}, geom_mu: {clusterer.geom_mu is not None}")
            raise RuntimeError("Clusterer not properly fitted - cannot save recipe")

        clusterer.save_state(str(clusterer_state_path))
        logger.info(f"Saved clusterer state to {clusterer_state_path}")

        def _numpy_or_none(attr):
            # Optional clusterer tensors may be absent or None.
            value = getattr(clusterer, attr, None)
            return value.cpu().numpy() if value is not None else None

        # Complete clusterer state embedded in the recipe
        full_clusterer_state = {
            'stats_mu': clusterer.stats_mu.cpu().numpy(),
            'stats_sigma': clusterer.stats_sigma.cpu().numpy(),
            'geom_mu': clusterer.geom_mu.cpu().numpy(),
            'geom_sigma': clusterer.geom_sigma.cpu().numpy(),
            'pair_to_final': clusterer.pair_to_final,
            'ks': _numpy_or_none('ks'),
            'kg': _numpy_or_none('kg'),
            'final_labels': _numpy_or_none('final_labels'),
        }

        feature_hash = hashlib.md5()
        feature_hash.update(str(self.autoencoder.state_dict().keys()).encode())
        feature_hash.update(str([p.shape for p in self.model.parameters()][:5]).encode())
        if has_simclr:
            feature_hash.update(str(self.simclr.state_dict().keys()).encode())

        conceptual_paths = self._extract_conceptual_paths()
        refinements = self.refine_recipe_with_performance()
        # A single evaluation supplies both numbers (this used to run two full evaluations).
        val_metrics = self.evaluate(self.val_loader)

        # The dict literal used to contain 'discovered_ordering' twice; the
        # second (first-100, unguarded) entry always won. That effective value
        # is kept, now guarded against a missing feature_weights.
        best_sequence = getattr(self, 'discovered_best_order', None) or []
        discovered_ordering = {
            'best_order_name': feature_weights.get('best_ordering', 'unknown'),
            'best_sequence': best_sequence[:100],  # First 100
            'ordering_scores': feature_weights.get('ordering_scores', {}),
            'tested_at_epoch': self.current_epoch
        }

        metric_weights = getattr(self, 'discovered_metric_weights', None)

        recipe = {
            'groups': groups,
            'discovered_ordering': discovered_ordering,
            # best_weights / best_active / best_score / weight_combinations were
            # probed via locals() but never defined in this method, so these
            # fallbacks are the values that were always written.
            'feature_weights': {
                'weights': feature_weights,
                'active_features': [],
                'selection_score': 0.0,
                'all_tested_combinations': 0,
            } if has_feature_weights else {},
            'labels': labels.cpu() if torch.is_tensor(labels) else labels,
            'transition_matrix': T.cpu() if torch.is_tensor(T) else T,
            'difficulty_profiles': difficulty_profiles,
            'refinements': refinements,
            'feature_config': feature_config,
            'proven_transitions': {
                k: v for k, v in getattr(self, 'transition_qualities', {}).items()
                if len(v) >= 5 and np.mean([q.get('success', False) for q in v]) > 0.6
            },
            'learned_metric_weights': metric_weights,
            'feature_weights_detailed': {
                'feature_weights': self.feature_weights if has_feature_weights else None,
                'metric_evaluation_weights': metric_weights,
                'n_weight_combinations_tested': getattr(self, 'n_combinations_tested', 0),
                'learning_correlation_scores': self.metric_score_history[-10:] if hasattr(self, 'metric_score_history') else []
            },
            'ae_state': self.autoencoder.state_dict(),
            'codebook': self.geom.codebook.cpu() if self.geom.codebook is not None else None,
            'prototypes': self.geom.prototypes.cpu() if self.geom.prototypes is not None else None,
            'codebook_dim': self.geom.codebook.shape[1] if self.geom.codebook is not None else None,
            'prototypes_dim': self.geom.prototypes.shape[1] if self.geom.prototypes is not None else None,
            'stats_clusterer': clusterer.stats_clusterer.get_params(),
            'geom_clusterer': clusterer.geom_clusterer.get_params(),
            'clusterer_state_path': str(clusterer_state_path),
            'clusterer_state': full_clusterer_state,
            'simclr_state': self.simclr.state_dict() if has_simclr else None,
            'teacher_metrics': {
                'clean_acc': val_metrics['clean'],
                'robust_acc': val_metrics['robust']
            },
            'config': asdict(self.cfg),
            'feature_compatibility_hash': feature_hash.hexdigest()[:16],
            'recipe_version': recipe_version,
            'creation_time': datetime.now().isoformat(),
            't_matrix_metadata': {
                'mode': getattr(self.cfg, 't_matrix_mode', 'throughout'),
                'total_transitions': int(self.transition_scheduler.counts.sum()) if hasattr(self, 'transition_scheduler') else 0,
                'transition_rewards': dict(self.transition_rewards) if hasattr(self, 'transition_rewards') else {},
                'discovered_paths': self._extract_high_reward_paths() if hasattr(self, 'transition_rewards') else [],
                'conceptual_paths': conceptual_paths
            },
        }

        recipe_path = self.cfg.experiment_dir / 'teacher_recipe_comprehensive.pkl'
        torch.save(recipe, recipe_path, pickle_protocol=4)
        logger.info(f"Saved comprehensive recipe (v{recipe_version}) to {recipe_path}")
        logger.info(f"  Feature mode: {feature_config['mode']}, dim: {feature_config['dim']}")
        logger.info(f"  Groups: {len(groups)}, Difficulty profiles: {len(difficulty_profiles)}")

        return recipe_path



    def extract_teacher_recipe(self):
        """Extract a compact recipe from the trained teacher and save it.

        Runs ``self.discovery_pass()``. Then, for up to 1000 members of each
        non-empty cluster, it runs a 7-step PGD with early stop and records
        the mean values of:

        * the attack success rate;
        * the input-gradient L2 norm;
        * the correct-class clean logit;
        * the PGD steps used.

        It also collects the member indices of clusters containing batches on
        which the stats and geometry views disagree. Everything is written to
        ``<experiment_dir>/teacher_recipe.pkl`` with ``torch.save``.

        Returns:
            Path of the saved recipe.
        """
        logger.info("Extracting teacher recipe...")

        # Run discovery on final teacher
        groups, labels, T, Xs, Xg = self.discovery_pass()

        # Calculate multi-dimensional difficulty profiles
        difficulty_profiles = {}
        for cid in range(len(groups)):
            indices = groups[cid][:1000]  # Sample for efficiency
            if len(indices) == 0:
                # Averaging empty lists would store NaN profiles.
                logger.warning(f"Cluster {cid} is empty; skipping difficulty profile")
                continue

            asr_list = []
            grad_norm_list = []
            margin_list = []
            steps_list = []

            for i in range(0, len(indices), self.cfg.batch_size):
                batch_indices = indices[i:i + self.cfg.batch_size]
                x = torch.stack([self.train_set[idx][0] for idx in batch_indices])
                y = torch.tensor([self.train_set[idx][1] for idx in batch_indices])
                x, y = x.to(self.device), y.to(self.device)

                # Test with PGD
                x_adv, steps = Attacks.pgd(
                    self.model, x, y,
                    self.cfg.epsilon, self.cfg.pgd_step_size, 7,
                    self.mean, self.std,
                    use_adv_bn=True, early_stop=True
                )

                with torch.no_grad():
                    logits_clean = self.model(Attacks.normalize(x, self.mean, self.std))
                    logits_adv = self.model(Attacks.normalize(x_adv, self.mean, self.std))

                    # ASR
                    asr_list.append((logits_adv.argmax(1) != y).float().mean().item())

                    # NOTE: despite the key name, this is the mean correct-class
                    # clean logit, not a logit margin.
                    margin_list.append(logits_clean.gather(1, y.unsqueeze(1)).mean().item())

                    # Steps used (scalar or per-sample tensor)
                    steps_list.append(steps.float().mean().item() if torch.is_tensor(steps) else float(steps))

                # Gradient norm on a fresh leaf so autograd.grad always works
                x_leaf = x_adv.detach().requires_grad_(True)
                loss = F.cross_entropy(self.model(Attacks.normalize(x_leaf, self.mean, self.std)), y)
                grad = torch.autograd.grad(loss, x_leaf)[0]
                grad_norm_list.append(grad.norm(p=2, dim=(1, 2, 3)).mean().item())

            # Multi-dimensional profile
            difficulty_profiles[cid] = {
                'asr': np.mean(asr_list),
                'grad_complexity': np.mean(grad_norm_list),
                'margin': np.mean(margin_list),
                'steps_needed': np.mean(steps_list),
                'overall': np.mean(asr_list) * 0.4 + (1 - np.mean(margin_list)) * 0.3 + np.mean(steps_list) / 10 * 0.3
            }

        # Identify uncertain batches (stats view and geometry view disagree).
        # ks / kg may exist but be None, which previously crashed torch.where.
        uncertain_indices = []
        ks = getattr(self.clusterer, 'ks', None)
        kg = getattr(self.clusterer, 'kg', None)
        if ks is not None and kg is not None:
            disagreement = torch.as_tensor(ks) != torch.as_tensor(kg)
            uncertain_idx = torch.where(disagreement)[0]
            for idx in uncertain_idx[:int(len(labels) * self.cfg.uncertainty_ratio)]:
                uncertain_indices.extend(groups[labels[idx].item()])
            # Each uncertain batch contributes its whole cluster, so the same
            # cluster was appended repeatedly; keep first occurrences only.
            uncertain_indices = list(dict.fromkeys(uncertain_indices))

        # Package recipe
        recipe = {
            'stats_mu': self.clusterer.stats_mu.cpu(),
            'stats_sigma': self.clusterer.stats_sigma.cpu(),
            'geom_mu': self.clusterer.geom_mu.cpu(),
            'geom_sigma': self.clusterer.geom_sigma.cpu(),
            'stats_centroids': self.clusterer.stats_clusterer.get_params(),
            'geom_centroids': self.clusterer.geom_clusterer.get_params(),
            'pair_to_final': self.clusterer.pair_to_final,
            'T': T.cpu() if torch.is_tensor(T) else T,
            'difficulty_profiles': difficulty_profiles,
            'groups': groups,
            'uncertain_indices': uncertain_indices,
            'teacher_robust_acc': self.evaluate(self.val_loader)['robust'],
            'config': asdict(self.cfg)
        }

        recipe_path = self.cfg.experiment_dir / 'teacher_recipe.pkl'
        torch.save(recipe, recipe_path)
        logger.info(f"Saved teacher recipe to {recipe_path}")
        return recipe_path

    def _initialize_student_adaptations(self):
        """Reset the student-side adaptation state.

        Sets three attributes:

        * ``student_difficulty_scale`` = 1.1, so clusters are treated as 10%
          harder at the start;
        * ``student_cluster_performance``, an empty ``defaultdict(list)``;
        * ``_fallback_cluster_idx`` = 0, the round-robin cursor used by
          ``_predict_cluster_robust``.
        """
        # Round-robin cursor for the last-resort cluster choice.
        self._fallback_cluster_idx = 0
        # Per-cluster performance history for online adaptation.
        self.student_cluster_performance = defaultdict(list)
        # Conservative start: inflate perceived difficulty by 10%.
        self.student_difficulty_scale = 1.1

        logger.info("Student adaptations initialized")

    def _validate_cluster_prediction_works(self):
        """Smoke-test cluster prediction on one shuffled training batch of 16.

        Returns:
            True if ``self.clusterer.predict_final_labels`` yields an id in
            ``[0, cfg.n_clusters)``. False when the id is out of range, the
            loader yields nothing, or any step raises. Each of those cases is
            logged as a warning.
        """
        try:
            # Test with small batch from training data
            probe_loader = DataLoader(self.train_set, batch_size=16, shuffle=True, num_workers=0)
            batch = next(iter(probe_loader), None)
            if batch is None:
                logger.warning("Cluster prediction validation failed: training set yielded no batch")
                return False

            x, y = batch
            x, y = x.to(self.device), y.to(self.device)

            # NOTE(review): _get_features_for_clustering can return None (when
            # 'stats'/'geom' are selected); presumably batch_geometry then
            # raises and the except below reports it — confirm.
            z = self._get_features_for_clustering(x)
            stats_feat = self.stats_profiler.extract_batch_stats(self.model, x, y, self.cfg.epsilon)
            geom_feat = self.geom_profiler.batch_geometry(z, x_batch=x, model=self.model)

            cluster_id = self.clusterer.predict_final_labels(
                stats_feat.unsqueeze(0),
                geom_feat.unsqueeze(0)
            )[0].item()

            if 0 <= cluster_id < self.cfg.n_clusters:
                logger.info(f"Recipe validation successful - predicted cluster {cluster_id}")
                return True

            # Previously this case returned False without any log line.
            logger.warning(
                f"Cluster prediction validation failed: cluster {cluster_id} "
                f"outside [0, {self.cfg.n_clusters})"
            )
        except Exception as e:
            logger.warning(f"Cluster prediction validation failed: {e}")

        return False

    def _predict_cluster_robust(self, x, y):
        """Return a cluster id in ``[0, cfg.n_clusters)`` for the batch ``(x, y)``.

        Strategies are tried in order:

          1. The fitted clusterer on batch-level stats and geometry features.
             Any exception, or an id outside the valid range, falls through.
          2. Sampling a cluster with probability proportional to
             ``difficulty ** (1 + epoch_factor)``. The difficulty is read from
             ``self.difficulty_profiles``: ``'overall_difficulty'``, then
             ``'overall'``, else 0.5. ``epoch_factor = max(0.1, 1 -
             current_epoch / cfg.epochs)``, so harder clusters are favored more
             strongly early in training. This strategy is skipped when there are
             no profiles or the weights do not sum to a positive finite value.
          3. Round-robin over clusters via ``self._fallback_cluster_idx``.

        Args:
            x: input batch.
            y: labels for ``x``.

        Returns:
            int cluster id.
        """
        n_clusters = self.cfg.n_clusters

        # Strategy 1: primary clusterer prediction
        try:
            with torch.no_grad():
                # NOTE(review): _get_features_for_clustering returns None when the
                # recipe selected batch-level features; batch_geometry(None, ...)
                # then presumably raises and we fall through to strategy 2.
                z = self._get_features_for_clustering(x)
                stats_feat = self.stats_profiler.extract_batch_stats(self.model, x, y, self.cfg.epsilon)
                geom_feat = self.geom_profiler.batch_geometry(z, x_batch=x, model=self.model)

                cluster_id = int(self.clusterer.predict_final_labels(
                    stats_feat.unsqueeze(0),
                    geom_feat.unsqueeze(0)
                )[0].item())

            if 0 <= cluster_id < n_clusters:
                return cluster_id

        except Exception as e:
            logger.debug(f"Primary cluster prediction failed: {e}")

        # Strategy 2: difficulty-weighted sampling
        try:
            profiles = getattr(self, 'difficulty_profiles', None)
            if profiles:
                difficulties = []
                for cid in range(n_clusters):
                    profile = profiles.get(cid)
                    if isinstance(profile, dict):
                        # Same key precedence as load_teacher_recipe uses
                        diff = profile.get('overall_difficulty', profile.get('overall', 0.5))
                    else:
                        diff = 0.5
                    difficulties.append(diff)

                # Sample harder clusters more frequently early in training
                epoch_factor = max(0.1, 1.0 - self.current_epoch / self.cfg.epochs)
                probs = np.asarray(difficulties, dtype=float) ** (1.0 + epoch_factor)
                total = probs.sum()
                if np.isfinite(total) and total > 0:
                    # Divide by the exact sum. np.random.choice rejects any ``p``
                    # whose sum differs from 1 by more than ~1e-8. Adding an
                    # epsilon to the denominator causes exactly that whenever the
                    # raw weights are small.
                    return int(np.random.choice(n_clusters, p=probs / total))

        except Exception as e:
            logger.debug(f"Difficulty-weighted sampling failed: {e}")

        # Strategy 3: round-robin fallback
        fallback_idx = getattr(self, '_fallback_cluster_idx', 0)
        self._fallback_cluster_idx = fallback_idx + 1
        return fallback_idx % n_clusters

    def _validate_recipe_application(self):
        """Smoke-test recipe-driven cluster prediction on a random dummy batch.

        Builds four random 3x32x32 inputs with integer labels drawn from
        [0, 10). It then asks ``_predict_cluster_robust`` for a cluster id.

        Returns:
            True if the predicted id lies in ``[0, cfg.n_clusters)``. False if
            it is out of range (logged as a warning) or any exception is raised
            (logged as an error).
        """
        try:
            # Inputs are drawn before labels, keeping the RNG consumption order.
            probe_images = torch.randn(4, 3, 32, 32).to(self.device)
            probe_labels = torch.randint(0, 10, (4,)).to(self.device)

            predicted = self._predict_cluster_robust(probe_images, probe_labels)
            logger.info(f"Recipe validation: predicted cluster {predicted}")

            in_range = bool(0 <= predicted < self.cfg.n_clusters)
            if not in_range:
                logger.warning(f"Invalid cluster prediction: {predicted}")
            return in_range

        except Exception as e:
            logger.error(f"Recipe validation failed: {e}")
            return False

    def _get_features_for_clustering(self, x):
        """Extract per-sample clustering features following the recipe's feature choice.

        Selection order:

          1. If ``self.feature_weights`` contains ``'selected_features'``:

             - ``None`` is returned when ``'stats'`` or ``'geom'`` is selected.
               This signals that batch-level extraction is needed instead.
             - Otherwise, the selected per-sample features are concatenated
               along dim 1. Each is scaled by its score from
               ``feature_weights['feature_scores']``. ``'confidence'`` gives
               [max softmax prob, softmax entropy]. ``'activations'`` gives
               spatially averaged layer3 activations. The autoencoder latent is
               always appended (default weight 0.5).

          2. Else, ``self.discovery_feature_mode`` decides:

             - ``'ae_model_simclr'``: [AE latent, backbone features, SimCLR
               features], with 128 zeros standing in when there is no SimCLR
               encoder.
             - ``'ae_model'``: [AE latent, backbone features].
             - anything else, or an unset mode: the AE latent alone.

        NOTE(review): when ``selected_features`` is present, branch 1 always
        returns, so ``discovery_feature_mode`` is ignored even if it is set.
        ``_apply_recipe_to_dataset`` uses the mode unconditionally. Confirm the
        two are meant to agree.

        Args:
            x: input batch. Presumably in raw pixel space, since it is passed
                through ``Attacks.normalize`` before reaching the classifier.

        Returns:
            A 2-D feature tensor, or ``None`` as described above.
        """
        with torch.no_grad():
            feature_weights = getattr(self, 'feature_weights', None)
            if feature_weights is not None and 'selected_features' in feature_weights:
                selected = feature_weights['selected_features']
                weights = feature_weights.get('feature_scores', {})

                if 'stats' in selected or 'geom' in selected:
                    # Batch-level features cannot be produced per sample here
                    return None

                # Normalize once, and only if a classifier-based feature needs it
                x_norm = None
                if 'confidence' in selected or 'activations' in selected:
                    x_norm = Attacks.normalize(x, self.mean, self.std)

                combined_features = []

                if 'confidence' in selected:
                    logits = self.model(x_norm)
                    probs = F.softmax(logits, dim=1)
                    # log_softmax stays finite where softmax underflows to exactly 0.
                    # probs.log() would give -inf there, and 0 * -inf = NaN.
                    entropy = -(probs * F.log_softmax(logits, dim=1)).sum(dim=1)
                    conf_feat = torch.stack([probs.max(dim=1)[0], entropy], dim=1)
                    combined_features.append(conf_feat * weights.get('confidence', 1.0))

                if 'activations' in selected:
                    h = self.model.relu(self.model.bn1(self.model.conv1(x_norm)))
                    h = self.model.layer3(self.model.layer2(self.model.layer1(h)))
                    h_pooled = h.mean(dim=[2, 3])
                    combined_features.append(h_pooled * weights.get('activations', 1.0))

                # The autoencoder latent is always included as the base feature
                z_ae = self.autoencoder.encode(x)
                combined_features.append(z_ae * weights.get('ae', 0.5))
                return torch.cat(combined_features, dim=1)

            mode = getattr(self, 'discovery_feature_mode', None)
            if mode == 'ae_model_simclr':
                if hasattr(self, 'simclr'):
                    z_simclr = self.simclr.extract_features(x)
                else:
                    z_simclr = torch.zeros(x.size(0), 128, device=x.device)
                return torch.cat([self.autoencoder.encode(x), self._get_model_features(x), z_simclr], dim=1)

            if mode == 'ae_model':
                return torch.cat([self.autoencoder.encode(x), self._get_model_features(x)], dim=1)

            return self.autoencoder.encode(x)



    def update_student_performance(self):
        """Adapt ``self.student_difficulty_scale`` from recent PGD step usage.

        This runs only in student mode with a non-None
        ``student_cluster_performance`` mapping. For each cluster with at least
        5 records, an efficiency is computed from the mean ``'steps_used'`` of
        its last 10 records (a missing value counts as ``cfg.pgd_steps``):
        ``(pgd_steps - avg_steps) / pgd_steps``.

        The mean efficiency across clusters is compared against a 0.3 target:

          - above 0.36, the scale grows by 5% (capped at 1.5);
          - below 0.24, it shrinks by 5% (floored at 0.9).

        Records may be held in a list or a ``deque``.
        """
        if not getattr(self, 'is_student_mode', False):
            return

        cluster_history = getattr(self, 'student_cluster_performance', None)
        if cluster_history is None:
            return

        max_steps = self.cfg.pgd_steps
        efficiencies = []

        for performance_list in cluster_history.values():
            if len(performance_list) < 5:
                continue

            # Convert before slicing: a deque does not support slice indexing
            recent_perf = list(performance_list)[-10:]
            avg_steps = np.mean([p.get('steps_used', max_steps) for p in recent_perf])

            # Simple efficiency metric (fewer steps = higher efficiency)
            efficiencies.append((max_steps - avg_steps) / max_steps)

        if not efficiencies:
            return

        overall_efficiency = np.mean(efficiencies)
        target_efficiency = 0.3  # Target 30% step savings

        if overall_efficiency > target_efficiency * 1.2:
            # Too easy, make harder
            self.student_difficulty_scale = min(1.5, self.student_difficulty_scale * 1.05)
        elif overall_efficiency < target_efficiency * 0.8:
            # Too hard, make easier
            self.student_difficulty_scale = max(0.9, self.student_difficulty_scale * 0.95)

        logger.info(f"🎓 Student efficiency: {overall_efficiency:.3f}, difficulty scale: {self.student_difficulty_scale:.3f}")

    def load_teacher_recipe(self, recipe_path):
        """Load a teacher recipe from disk and switch this trainer into student mode.

        Steps, in order:

          1. Set ``is_student_mode`` and ``skip_discovery``.
          2. ``torch.load`` the recipe into ``self._loaded_recipe``. Check that
             ``difficulty_profiles``, ``clusterer_state``, ``groups`` and
             ``transition_matrix`` are present and not None.
          3. Warn if the feature-architecture hash differs from the recipe's.
          4. Restore the feature weights and feature mode, building SimCLR from
             the recipe if needed. Restore the AE weights and the
             codebook/prototypes.
          5. Restore the clusterer, from ``clusterer_state_path`` when it can be
             loaded and otherwise from the embedded ``clusterer_state``.
          6. Restore the groups, the transition matrix ``self.T``, the
             difficulty profiles and the discovered ordering. Put the
             model/AE/SimCLR in eval mode.
          7. If ``self.train_loader`` is set, re-derive the groups on its
             dataset when they are empty or reference indices beyond its length.
          8. Seed ``cluster_difficulties`` from the profiles, initialize the
             student adaptations, and run the prediction smoke tests.

        Args:
            recipe_path: path to a file written with ``torch.save``.

        Raises:
            ValueError: if required recipe keys are missing, or no clusterer
                state could be restored.
        """
        logger.info(f"Loading teacher recipe from {recipe_path}")

        # CRITICAL: Set student mode flags FIRST
        self.is_student_mode = True
        self.skip_discovery = True
        logger.info("Student mode activated - will use teacher's knowledge immediately")

        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load recipe files from a trusted source.
        recipe = torch.load(recipe_path, weights_only=False, map_location=self.device)
        self._loaded_recipe = recipe

        # Validate recipe completeness
        required_keys = ['difficulty_profiles', 'clusterer_state', 'groups', 'transition_matrix']
        missing_keys = [k for k in required_keys if recipe.get(k) is None]
        if missing_keys:
            raise ValueError(f"Incomplete recipe missing: {missing_keys}")

        # Check recipe version
        recipe_version = recipe.get('recipe_version', '1.0')
        logger.info(f"Loading recipe version {recipe_version}")

        # Older recipes may lack feature_config entirely; every lookup uses .get()
        feature_config = recipe.get('feature_config') or {}

        # Check feature compatibility if hash exists (warning only)
        if 'feature_compatibility_hash' in recipe:
            import hashlib
            current_hash = hashlib.md5()
            current_hash.update(str(self.autoencoder.state_dict().keys()).encode())
            current_hash.update(str([p.shape for p in self.model.parameters()][:5]).encode())
            if hasattr(self, 'simclr'):
                current_hash.update(str(self.simclr.state_dict().keys()).encode())

            if current_hash.hexdigest()[:16] != recipe['feature_compatibility_hash']:
                logger.warning("Feature architecture mismatch detected - recipe may be incompatible")
                logger.warning(f"Expected: {recipe['feature_compatibility_hash']}, Got: {current_hash.hexdigest()[:16]}")

        # Load feature configuration
        if 'feature_weights' in recipe:
            self.feature_weights = recipe['feature_weights']
            self.selected_features = self.feature_weights.get('selected_features', [])
            self.feature_scores = self.feature_weights.get('feature_scores', {})
            logger.info(f"Loaded feature discovery: using {self.selected_features} with scores {self.feature_scores}")

            # Validate we can reproduce the features
            if feature_config.get('use_simclr', False) and not hasattr(self, 'simclr'):
                logger.warning("Recipe uses SimCLR but it's not initialized!")
                if recipe.get('simclr_state'):
                    logger.info("Loading SimCLR from recipe...")
                    self.simclr = SimCLREncoder(base_model='resnet18').to(self.device)
                    self.simclr.load_state_dict(recipe['simclr_state'])

            # Set the feature mode before any group re-derivation below, because
            # _apply_recipe_to_dataset() reads self.discovery_feature_mode.
            self.discovered_feature_weights = recipe['feature_weights']
            self.discovery_feature_mode = recipe['feature_weights'].get('feature_combination_mode', 'ae_model')
            logger.info(f"Loaded feature weights with mode: {self.discovery_feature_mode}")

        # Load autoencoder state
        if 'ae_state' in recipe:
            self.autoencoder.load_state_dict(recipe['ae_state'])
            logger.info("Loaded autoencoder state from recipe")

        # Load geometry profiler state
        if recipe.get('codebook') is not None:
            self.geom.codebook = recipe['codebook'].to(self.device)
            prototypes = recipe.get('prototypes')
            if prototypes is not None:
                self.geom.prototypes = prototypes.to(self.device)
            else:
                logger.warning("Recipe has a codebook but no prototypes; keeping existing prototypes")
            expected_dim = recipe.get('codebook_dim')
            if expected_dim:
                self.geom.expected_dim = expected_dim
            logger.info(f"Loaded codebook and prototypes (dim={expected_dim})")

        # Load clusterer state: prefer the complete on-disk state
        loaded_from_file = False
        if 'clusterer_state_path' in recipe:
            clusterer_path = recipe['clusterer_state_path']
            if os.path.exists(clusterer_path):
                try:
                    self.clusterer.load_state(clusterer_path)
                    logger.info(f"Loaded complete clusterer state from {clusterer_path}")
                    loaded_from_file = True

                    # Move clusterer scalers to correct device with safety checks
                    for attr_name in ['stats_mu', 'stats_sigma', 'geom_mu', 'geom_sigma']:
                        attr_val = getattr(self.clusterer, attr_name, None)
                        if attr_val is not None:
                            if hasattr(attr_val, 'to'):
                                setattr(self.clusterer, attr_name, attr_val.to(self.device))
                            else:
                                # Convert numpy to tensor if needed
                                setattr(self.clusterer, attr_name, torch.tensor(attr_val).to(self.device))

                except Exception as e:
                    logger.warning(f"Failed to load from file: {e}, will use embedded state")

        # Fall back to the embedded state only when the file could not be used
        if (not loaded_from_file) and ('clusterer_state' in recipe):
            cs = recipe['clusterer_state']

            def _device_tensor(value):
                # Old recipes store arrays as numpy; newer ones may store tensors
                return None if value is None else torch.as_tensor(value).to(self.device)

            # Fresh (unfitted) per-view clusterers
            self.clusterer.stats_clusterer = MiniBatchKMeansClusterer(n_clusters=self.cfg.K_stats)
            self.clusterer.geom_clusterer = MiniBatchKMeansClusterer(n_clusters=self.cfg.K_geom)

            # Load scalers and mappings
            self.clusterer.stats_mu = _device_tensor(cs.get('stats_mu'))
            self.clusterer.stats_sigma = _device_tensor(cs.get('stats_sigma'))
            self.clusterer.geom_mu = _device_tensor(cs.get('geom_mu'))
            self.clusterer.geom_sigma = _device_tensor(cs.get('geom_sigma'))
            self.clusterer.pair_to_final = cs['pair_to_final']

            # Load additional label state if available
            for key in ('ks', 'kg', 'final_labels'):
                value = cs.get(key)
                if value is not None:
                    setattr(self.clusterer, key, _device_tensor(value).long())

            logger.warning("Loaded partial clusterer state from old recipe format")
        elif not loaded_from_file:
            raise ValueError("No clusterer state found in recipe")

        # Set groups, transition matrix and difficulties
        self.groups = recipe['groups']
        transition = recipe['transition_matrix']
        self.T = transition.to(self.device) if torch.is_tensor(transition) else torch.tensor(transition).to(self.device)
        self.difficulty_profiles = recipe['difficulty_profiles']

        if 'discovered_ordering' in recipe:
            self.discovered_best_order = recipe['discovered_ordering'].get('best_sequence', [])
            if self.discovered_best_order:
                logger.info(f"Loaded discovered ordering: {recipe['discovered_ordering'].get('best_order_name', 'unknown')}")

        # CRITICAL: switch to eval mode BEFORE any feature extraction below.
        # In training mode, BatchNorm layers update their running statistics
        # even under no_grad, which would corrupt the model.
        self.model.eval()
        self.autoencoder.eval()
        if hasattr(self, 'simclr'):
            self.simclr.eval()

        # Validate that groups fit the current training set (if it is known yet)
        if getattr(self, 'train_loader', None) is None:
            logger.warning("train_loader not set yet, deferring group validation")
        else:
            non_empty = [indices for indices in self.groups.values() if len(indices) > 0]
            if not non_empty:
                logger.warning("No groups in recipe, applying to dataset...")
                self._apply_recipe_to_dataset(self.train_loader.dataset)
            else:
                max_idx_in_groups = max(max(indices) for indices in non_empty)
                dataset_size = len(self.train_loader.dataset)

                if max_idx_in_groups >= dataset_size:
                    logger.warning(f"Groups reference indices up to {max_idx_in_groups} but dataset has {dataset_size} samples")
                    logger.info("Re-applying recipe to current dataset...")
                    # Groups don't match current dataset, need to recompute
                    self._apply_recipe_to_dataset(self.train_loader.dataset)

        # Initialize cluster difficulties from profiles
        for cid, profile in self.difficulty_profiles.items():
            if isinstance(profile, dict) and 'overall_difficulty' in profile:
                self.cluster_difficulties[cid] = profile['overall_difficulty']
            elif isinstance(profile, dict) and 'overall' in profile:
                self.cluster_difficulties[cid] = profile['overall']
            else:
                self.cluster_difficulties[cid] = 0.5  # Default

        logger.info(f"Loaded recipe with {len(self.groups)} clusters")
        logger.info(f"Teacher robust accuracy: {recipe.get('teacher_metrics', {}).get('robust_acc', 'N/A')}")
        logger.info(f"Cluster difficulties: {self.cluster_difficulties}")

        # CRITICAL: Initialize student-specific adaptations
        self._initialize_student_adaptations()

        # CRITICAL: Pre-validate cluster prediction works
        test_success = self._validate_cluster_prediction_works()
        self._validate_recipe_application()
        if not test_success:
            logger.warning("Recipe cluster prediction validation failed - using robust fallbacks")


    def train_with_recipe(self, train_loader, val_loader, test_loader):
        """Train as a student, using the loaded teacher recipe instead of discovery.

        The method:

          1. Validates the recipe transfer on ``train_loader``.
          2. Disables periodic discovery by setting ``cfg.discovery_interval = 999``.
          3. Ensures ``self.groups`` is usable for ``train_loader.dataset``,
             re-deriving the groups when they are missing, empty, or
             reference indices beyond the dataset length.
          4. Delegates to ``self.train``.

        Step 3 also covers the check that ``load_teacher_recipe`` defers when
        ``train_loader`` is not yet set.

        Returns:
            Whatever ``self.train(train_loader, val_loader, test_loader)`` returns.

        Raises:
            ValueError: if recipe transfer validation fails.
        """
        # First validate the transfer works
        if not self.validate_recipe_transfer(train_loader):
            raise ValueError("Recipe transfer validation failed - cannot proceed with student training")

        # Skip discovery - the recipe already provides clusters
        self.cfg.discovery_interval = 999

        dataset = train_loader.dataset
        groups = getattr(self, 'groups', None)
        non_empty = [indices for indices in groups.values() if len(indices) > 0] if groups else []

        if not non_empty:
            # Missing or empty groups: derive them from the teacher's clusterer
            self._apply_recipe_to_dataset(dataset)
        elif max(max(indices) for indices in non_empty) >= len(dataset):
            logger.warning("Recipe groups reference indices beyond the training set; re-applying recipe to current dataset")
            self._apply_recipe_to_dataset(dataset)

        # Standard training but with recipe-guided curriculum
        return self.train(train_loader, val_loader, test_loader)

    def _apply_recipe_to_dataset(self, dataset):
        """Assign every sample of ``dataset`` to a teacher cluster and rebuild ``self.groups``.

        The dataset is walked in order (``shuffle=False``) in batches of
        ``cfg.batch_size``. For each batch:

          - features are built according to ``discovery_feature_mode``, using
            the AE latent alone when the mode is unset or unrecognized;
          - batch-level stats and geometry are computed;
          - the clusterer predicts ONE label, which is assigned to every sample
            of the batch.

        ``self.groups`` becomes a ``defaultdict(list)`` mapping cluster label
        to a list of sample indices.

        Index convention: for a plain dataset the indices are positions
        ``0..len(dataset)-1``. For a ``Subset``-like dataset (one with
        ``.indices``) they are the underlying ``dataset.indices[pos]`` values.

        Args:
            dataset: map-style dataset yielding ``(x, y)`` pairs.
        """
        # shuffle=False so each batch position maps back to a dataset position
        loader = DataLoader(dataset, batch_size=self.cfg.batch_size, shuffle=False)
        # Same default as _get_features_for_clustering: AE-only when no mode is set
        feature_mode = getattr(self, 'discovery_feature_mode', None)

        all_indices = []
        all_labels = []

        # Handle Subset vs regular dataset
        is_subset = hasattr(dataset, 'indices')

        for batch_idx, (x, y) in enumerate(loader):
            x = x.to(self.device)
            # Labels must be on the same device as the inputs for the stats profiler
            y = y.to(self.device)

            # Extract features using the teacher's feature mode
            with torch.no_grad():
                if feature_mode == 'ae_model_simclr':
                    z_ae = self.autoencoder.encode(x)
                    z_model = self._get_model_features(x)
                    z_simclr = self.simclr.extract_features(x) if hasattr(self, 'simclr') else torch.zeros(x.size(0), 128, device=x.device)
                    z = torch.cat([z_ae, z_model, z_simclr], dim=1)
                elif feature_mode == 'ae_model':
                    z_ae = self.autoencoder.encode(x)
                    z_model = self._get_model_features(x)
                    z = torch.cat([z_ae, z_model], dim=1)
                else:
                    z = self.autoencoder.encode(x)

            # Deliberately outside no_grad: presumably the stats profiler needs
            # gradients, since it receives cfg.epsilon. TODO confirm.
            stats_feat = self.stats_profiler.extract_batch_stats(self.model, x, y, self.cfg.epsilon)
            geom_feat = self.geom_profiler.batch_geometry(z, x_batch=x, model=self.model)

            # One cluster label for the whole batch
            cluster_label = self.clusterer.predict_final_labels(
                stats_feat.unsqueeze(0),
                geom_feat.unsqueeze(0)
            )[0].item()

            actual_batch_size = x.size(0)  # The last batch may be smaller than cfg.batch_size

            if is_subset:
                batch_indices = []
                for i in range(actual_batch_size):
                    dataset_idx = batch_idx * self.cfg.batch_size + i
                    if dataset_idx < len(dataset.indices):
                        batch_indices.append(dataset.indices[dataset_idx])
            else:
                start_idx = batch_idx * self.cfg.batch_size
                end_idx = min(start_idx + actual_batch_size, len(dataset))
                batch_indices = list(range(start_idx, end_idx))

            # Single check for empty batch indices
            if not batch_indices:
                logger.warning(f"Empty batch indices at batch {batch_idx}, skipping")
                continue

            # NOTE(review): for subsets, batch_indices hold UNDERLYING indices
            # (dataset.indices[pos]). They are validated against
            # len(dataset.indices), which is the number of subset positions, so
            # any underlying index >= that length is dropped. Confirm which index
            # convention the consumers of self.groups expect.
            if is_subset:
                max_valid_idx = len(dataset.indices)
                invalid_indices = [idx for idx in batch_indices if idx >= max_valid_idx or idx < 0]
                if invalid_indices:
                    logger.error(f"Invalid indices detected: {invalid_indices}, max valid: {max_valid_idx}")
                    batch_indices = [idx for idx in batch_indices if 0 <= idx < max_valid_idx]
                    if not batch_indices:
                        logger.warning(f"No valid indices remaining for batch {batch_idx}, skipping")
                        continue

            all_indices.extend(batch_indices)
            all_labels.extend([cluster_label] * len(batch_indices))

        # Build groups from predictions
        self.groups = defaultdict(list)
        for idx, label in zip(all_indices, all_labels):
            self.groups[label].append(idx)

        logger.info(f"Applied recipe to dataset: {len(self.groups)} groups")
        for k, v in self.groups.items():
            logger.info(f"  Group {k}: {len(v)} samples")


    def validate_recipe_transfer(self, sample_loader=None):
        """Check that this trainer can reproduce the loaded teacher recipe's pipeline.

        After some up-front checks, one batch (at most 16 samples) from
        ``sample_loader`` is pushed through every stage the student relies on:

          1. feature extraction in ``self.discovery_feature_mode``, plus a
             dimension check against ``self.discovery_feature_dim``;
          2. geometry features via ``self.geom``, with a codebook-width check;
          3. stats features via ``self.stats``;
          4. the clusterer's ``transform_features`` and ``predict_final_labels``;
          5. presence of the expected keys in one difficulty profile (warning only);
          6. a repeat end-to-end prediction.

        Stages that depend on an earlier failed stage are reported as skipped
        rather than run on undefined features.

        Args:
            sample_loader: iterable of ``(x, y)`` batches; defaults to
                ``self.train_loader``.

        Returns:
            True when every check passed, otherwise False (each failure is
            logged). False is also returned immediately if the clusterer's
            ``stats_mu`` or ``geom_mu`` scaler is missing, and when the loader
            yields no batch.

        Raises:
            ValueError: if ``clusterer``, ``discovery_feature_mode`` or
                ``discovery_feature_dim`` is not set on the trainer.
        """
        logger.info("Validating recipe transfer...")

        # Warn (do not fail) if the feature architecture differs from the teacher's
        recipe = getattr(self, '_loaded_recipe', None)
        if recipe is not None and 'feature_compatibility_hash' in recipe:
            import hashlib
            current_hash = hashlib.md5()
            current_hash.update(str(self.autoencoder.state_dict().keys()).encode())
            current_hash.update(str([p.shape for p in self.model.parameters()][:5]).encode())
            if hasattr(self, 'simclr'):
                current_hash.update(str(self.simclr.state_dict().keys()).encode())

            if current_hash.hexdigest()[:16] != recipe['feature_compatibility_hash']:
                logger.warning("Feature architecture mismatch detected - recipe may be incompatible")

        # Use train loader if not provided
        if sample_loader is None:
            sample_loader = self.train_loader

        if not hasattr(self, 'clusterer'):
            raise ValueError("Clusterer not initialized")

        # STRICT CHECK: the clusterer's feature scalers must be loaded to proceed
        if self.clusterer.stats_mu is None or self.clusterer.geom_mu is None:
            logger.error("CRITICAL: Clusterer scalers not loaded - recipe is incomplete")
            logger.error(f"stats_mu exists: {self.clusterer.stats_mu is not None}")
            logger.error(f"geom_mu exists: {self.clusterer.geom_mu is not None}")
            logger.error("Cannot proceed with student training - recipe validation failed")
            return False

        if not hasattr(self, 'discovery_feature_mode'):
            raise ValueError("Feature mode not specified in recipe")

        if not hasattr(self, 'discovery_feature_dim'):
            raise ValueError("Feature dimension not specified in recipe")

        mode = self.discovery_feature_mode
        validation_passed = True
        errors = []

        try:
            tested_batch = False
            for x, y in sample_loader:
                tested_batch = True
                x = x.to(self.device)[:16]  # Test with at most 16 samples
                y = y.to(self.device)[:16]

                logger.info(f"Testing with batch of {x.size(0)} samples")

                z = None
                stats_feat = None
                geom_feat = None

                # Step 1: feature extraction in the teacher's mode
                with torch.no_grad():
                    try:
                        z_ae = z_model = z_simclr = None
                        if mode == 'ae_model_simclr':
                            # Need all three components
                            if not hasattr(self, 'simclr'):
                                errors.append("SimCLR required but not loaded")
                                validation_passed = False
                            else:
                                z_ae = self.autoencoder.encode(x)
                                z_model = self._get_model_features(x)
                                z_simclr = self.simclr.extract_features(x)
                                z = torch.cat([z_ae, z_model, z_simclr], dim=1)
                        elif mode == 'ae_model':
                            z_ae = self.autoencoder.encode(x)
                            z_model = self._get_model_features(x)
                            z = torch.cat([z_ae, z_model], dim=1)
                        else:  # ae_only
                            z = self.autoencoder.encode(x)

                        # Strict feature-dimension check (a mismatch breaks clustering)
                        if z is not None:
                            if z.size(1) != self.discovery_feature_dim:
                                errors.append(f"CRITICAL: Feature dimension mismatch - expected {self.discovery_feature_dim}, got {z.size(1)}")
                                validation_passed = False
                                logger.error(f"Student model produces {z.size(1)}-dim features but teacher expects {self.discovery_feature_dim}")
                                logger.error("This will cause clustering prediction failures")
                            else:
                                logger.info(f"✓ Feature extraction successful: dim={z.size(1)}")

                        # Validate individual component dimensions for combined features
                        if z is not None and mode == 'ae_model_simclr':
                            # cfg is a Config dataclass: use getattr, it has no .get()
                            expected_ae_dim = getattr(self.cfg, 'ae_latent_dim', 64)
                            expected_model_dim = 512  # ResNet18 feature dim
                            expected_simclr_dim = 128  # Default SimCLR dim

                            if z_ae.size(1) != expected_ae_dim:
                                errors.append(f"AE dimension mismatch: got {z_ae.size(1)}, expected {expected_ae_dim}")
                                validation_passed = False
                            if z_model.size(1) != expected_model_dim:
                                errors.append(f"Model dimension mismatch: got {z_model.size(1)}, expected {expected_model_dim}")
                                validation_passed = False
                            if z_simclr.size(1) != expected_simclr_dim:
                                errors.append(f"SimCLR dimension mismatch: got {z_simclr.size(1)}, expected {expected_simclr_dim}")
                                validation_passed = False

                    except Exception as e:
                        errors.append(f"Feature extraction failed: {str(e)}")
                        validation_passed = False

                # Step 2: geometry profiler on the extracted features
                if z is None:
                    errors.append("Geometry check skipped: no features were extracted")
                    validation_passed = False
                else:
                    try:
                        if self.geom.codebook is not None and self.geom.codebook.shape[1] != z.size(1):
                            errors.append(f"Codebook dim {self.geom.codebook.shape[1]} != feature dim {z.size(1)}")
                            validation_passed = False

                        geom_feat = self.geom.batch_geometry(z, x_batch=x, model=self.model)
                        logger.info(f"✓ Geometry features computed: dim={geom_feat.size(0)}")

                    except Exception as e:
                        errors.append(f"Geometry profiler failed: {str(e)}")
                        validation_passed = False

                # Step 3: stats profiler
                try:
                    stats_feat = self.stats.extract_batch_stats(self.model, x, y, self.cfg.epsilon)
                    logger.info(f"✓ Stats features computed: dim={stats_feat.size(0)}")
                except Exception as e:
                    errors.append(f"Stats profiler failed: {str(e)}")
                    validation_passed = False

                # Step 4: clustering transform + prediction
                if stats_feat is None or geom_feat is None:
                    errors.append("Clustering check skipped: stats or geometry features unavailable")
                    validation_passed = False
                else:
                    try:
                        # Exercise the scaler transform on its own first
                        self.clusterer.transform_features(
                            stats_feat.unsqueeze(0),
                            geom_feat.unsqueeze(0)
                        )

                        cluster_label = self.clusterer.predict_final_labels(
                            stats_feat.unsqueeze(0),
                            geom_feat.unsqueeze(0)
                        )[0].item()

                        logger.info(f"✓ Clustering prediction successful: assigned to cluster {cluster_label}")

                        if not 0 <= cluster_label < self.cfg.n_clusters:
                            errors.append(f"Invalid cluster label {cluster_label} outside [0, {self.cfg.n_clusters})")
                            validation_passed = False

                    except Exception as e:
                        errors.append(f"Clustering prediction failed: {str(e)}")
                        validation_passed = False

                # Step 5: difficulty profiles (missing keys only warn)
                if hasattr(self, 'difficulty_profiles') and self.difficulty_profiles:
                    try:
                        test_cluster = list(self.difficulty_profiles.keys())[0]
                        profile = self.difficulty_profiles[test_cluster]

                        expected_keys = ['asr', 'margin', 'grad_complexity', 'overall_difficulty']
                        missing_keys = [k for k in expected_keys if k not in profile]

                        if missing_keys:
                            logger.warning(f"Difficulty profile missing keys: {missing_keys}")
                        else:
                            logger.info(f"✓ Difficulty profiles validated")

                    except Exception as e:
                        errors.append(f"Difficulty profile validation failed: {str(e)}")

                # Step 6: full pipeline, as it would run during training
                try:
                    with torch.no_grad():
                        if mode == 'ae_model_simclr':
                            z_final = torch.cat([
                                self.autoencoder.encode(x),
                                self._get_model_features(x),
                                self.simclr.extract_features(x) if hasattr(self, 'simclr') else torch.zeros(x.size(0), 128).to(self.device)
                            ], dim=1)
                        elif mode == 'ae_model':
                            z_final = torch.cat([
                                self.autoencoder.encode(x),
                                self._get_model_features(x)
                            ], dim=1)
                        else:
                            z_final = self.autoencoder.encode(x)

                        final_geom = self.geom.batch_geometry(z_final, x_batch=x, model=self.model)
                        final_stats = self.stats.extract_batch_stats(self.model, x, y, self.cfg.epsilon)

                        final_cluster = self.clusterer.predict_final_labels(
                            final_stats.unsqueeze(0),
                            final_geom.unsqueeze(0)
                        )[0].item()

                    logger.info(f"✓ Full pipeline test successful - cluster {final_cluster}")

                except Exception as e:
                    errors.append(f"Full pipeline test failed: {str(e)}")
                    validation_passed = False

                # One batch is enough; stop before the loader fetches another
                break

            if not tested_batch:
                errors.append("Sample loader yielded no batches - nothing was validated")
                validation_passed = False

        except Exception as e:
            errors.append(f"Validation failed with unexpected error: {str(e)}")
            validation_passed = False

        # Report results
        if validation_passed:
            logger.info("="*50)
            logger.info("RECIPE TRANSFER VALIDATION: PASSED ✓")
            logger.info("="*50)
            logger.info("Student can successfully reproduce teacher's:")
            logger.info(f"  - Feature extraction ({self.discovery_feature_mode})")
            logger.info(f"  - Feature dimension ({self.discovery_feature_dim})")
            logger.info(f"  - Geometry profiling")
            logger.info(f"  - Stats profiling")
            logger.info(f"  - Cluster prediction")
            return True

        logger.error("="*50)
        logger.error("RECIPE TRANSFER VALIDATION: FAILED ✗")
        logger.error("="*50)
        for error in errors:
            logger.error(f"  ✗ {error}")
        logger.error("Student cannot properly use teacher's recipe")
        return False

    def _get_model_features(self, x):
        """Return the classifier's pooled penultimate features for ``x``.

        ``x`` is normalised with ``self.mean`` / ``self.std`` and pushed through the
        stem (conv1 -> bn1 -> relu) and ``layer1``..``layer4`` of ``self.model``, then
        ``avgpool``; the pooled map is flattened to ``(batch, features)``.  Runs under
        ``torch.no_grad()``.
        """
        net = self.model
        with torch.no_grad():
            h = Attacks.normalize(x, self.mean, self.std)
            h = net.relu(net.bn1(net.conv1(h)))
            for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
                h = stage(h)
            h = net.avgpool(h)
            return h.view(h.size(0), -1)

    def update_cluster_difficulties(self):
        """Refresh ``self.cluster_difficulties`` from recent robust accuracy.

        For every cluster id in ``range(cfg.n_clusters)``:

        * no entry in ``self.cluster_robust_history`` -> difficulty reset to 0.5;
        * otherwise, when both the last 10 robust-accuracy values and the last 10
          PGD-efficiency values are non-empty, the observed difficulty
          ``1 - mean(recent robust accuracy)`` (0 = easy, 1 = hard) is blended into
          the stored value with an EMA (alpha = 0.2, prior 0.5) and clamped to
          ``[0.3, 0.85]``;
        * otherwise the stored difficulty is left untouched.
        """
        alpha = 0.2  # Slower EMA for stability

        for cluster_id in range(self.cfg.n_clusters):
            if cluster_id not in self.cluster_robust_history:
                self.cluster_difficulties[cluster_id] = 0.5
                continue

            recent_robust = self.cluster_robust_history[cluster_id][-10:]
            # .get(): the efficiency log need not have an entry for every cluster that
            # has robust history; a plain-dict lookup would raise KeyError here.
            recent_efficiency = self.cluster_pgd_efficiency.get(cluster_id, [])[-10:]

            # NOTE(review): efficiency only gates the update; it does not enter the score.
            if not (recent_robust and recent_efficiency):
                continue

            observed = 1.0 - float(np.mean(recent_robust))  # Simple: just use robust acc

            old = self.cluster_difficulties.get(cluster_id, 0.5)
            new = alpha * observed + (1 - alpha) * old

            # Clamp to prevent saturation (keep away from 0 and 1)
            self.cluster_difficulties[cluster_id] = float(np.clip(new, 0.3, 0.85))

    def finalize_clusters(self):
        """Re-weight the transition matrix by how well each transition worked.

        A transition ``(from_c, to_c)`` in ``self.transition_qualities`` is "proven"
        when it has at least 5 quality records, more than 60% of them carry a truthy
        ``'success'`` flag, and the mean ``'robust_gain'`` of the successful records
        exceeds 0.01.  Proven entries of ``self.transition_scheduler.counts`` are
        multiplied by ``1 + 10 * avg_gain``; every other entry is halved.  Rows are
        then normalised (all-zero rows stay zero) and stored in ``self.T`` as a
        tensor on ``self.device``.
        """
        logger.info("Finalizing clusters based on observed performance...")

        # Analyze which transitions actually worked
        successful_transitions = {}
        for (from_c, to_c), qualities in self.transition_qualities.items():
            if len(qualities) < 5:  # Need enough samples
                continue
            successes = [q for q in qualities if q.get('success', False)]
            success_rate = len(successes) / len(qualities)
            avg_gain = np.mean([q['robust_gain'] for q in successes]) if successes else 0

            if success_rate > 0.6 and avg_gain > 0.01:  # Good transition
                successful_transitions[(from_c, to_c)] = {
                    'success_rate': success_rate,
                    'avg_gain': avg_gain,
                    'count': len(qualities)
                }

        # Rebuild T-matrix emphasizing successful transitions.
        counts = np.asarray(self.transition_scheduler.counts)
        # Multiplying by a Python float promotes integer counts to float64 (float32
        # stays float32).  np.zeros_like(counts) would keep an integer dtype for
        # integer counts and silently truncate the 0.5 penalty / fractional boost.
        T_new = counts * 0.5  # Penalize unproven transitions by default
        n_rows, n_cols = T_new.shape
        for (i, j), info in successful_transitions.items():
            if 0 <= i < n_rows and 0 <= j < n_cols:
                # Boost successful transitions
                T_new[i, j] = counts[i, j] * (1 + info['avg_gain'] * 10)

        # Normalize rows; leave all-zero rows as zeros
        row_sums = T_new.sum(axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1
        self.T = torch.tensor(T_new / row_sums, device=self.device)

        logger.info(f"Finalized T-matrix with {len(successful_transitions)} proven transitions")
        for (from_c, to_c), info in successful_transitions.items():
            logger.info(f"  {from_c}→{to_c}: {info['success_rate']:.1%} success, +{info['avg_gain']:.3f} robust")

    def find_cycles_in_T(self, T, min_length=3, min_weight=0.1):
        """Enumerate the simple directed cycles of a transition matrix.

        An edge ``i -> j`` exists when ``i != j`` and ``T[i, j] > min_weight``.  Each
        cycle is reported exactly once, as a node list that starts and ends at its
        smallest node (e.g. ``[0, 2, 1, 0]``).  The two orientations of a cycle use
        different edges and are reported separately when both exist.

        Args:
            T: square matrix indexable as ``T[i, j]`` (NumPy array or tensor).
            min_length: minimum number of distinct nodes in a reported cycle.
            min_weight: strict lower bound on an entry for it to count as an edge.

        Returns:
            List of cycles; each is a list of node ids with the first repeated last.
        """
        n = T.shape[0]

        # Build adjacency from T (self-loops excluded)
        edges = {i: [j for j in range(n) if i != j and T[i, j] > min_weight] for i in range(n)}
        cycles = []

        def dfs(start, current, visited, path):
            # Close the cycle when allowed, but keep extending: longer cycles that
            # share this prefix must still be discovered.
            if len(path) >= min_length and start in edges[current]:
                cycles.append(path + [start])
            for next_node in edges[current]:
                # Only walk to nodes larger than ``start`` so every cycle is rooted at
                # its smallest node -> no rotated duplicates.
                if next_node > start and next_node not in visited:
                    dfs(start, next_node, visited | {next_node}, path + [next_node])

        for start in range(n):
            dfs(start, start, {start}, [start])

        return cycles

    def _get_trades_beta(self):
        """Return the TRADES beta in effect for ``self.current_epoch``.

        The beta is 0.0 while ``current_epoch < cfg.beta_warmup_epochs`` and
        ``cfg.trades_beta`` from then on.
        """
        still_warming_up = self.current_epoch < self.cfg.beta_warmup_epochs
        return 0.0 if still_warming_up else self.cfg.trades_beta

    def _current_trades_beta(self):
        """Return the current TRADES beta (delegates to ``_get_trades_beta``)."""
        beta = self._get_trades_beta()
        return beta


    def _build_model(self):
        """Build a 10-class ResNet-18 (optionally with Dual-BN) on ``self.device``.

        Uses ``DualBatchNorm2d`` as the norm layer when ``cfg.use_dual_bn`` is set,
        otherwise ``nn.BatchNorm2d``.  The torchvision stem is replaced by a 3x3,
        stride-1 convolution and the max-pool is removed; every ``nn.ReLU`` is made
        non-inplace.
        """
        from torchvision.models.resnet import BasicBlock, ResNet

        if self.cfg.use_dual_bn:
            bn_cls = DualBatchNorm2d
        else:
            bn_cls = nn.BatchNorm2d

        class ResNetCustom(ResNet):
            def __init__(self):
                super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=10, norm_layer=bn_cls)
                # Small-image stem: 3x3 stride-1 conv, no downsampling max-pool
                self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
                self.maxpool = nn.Identity()

        net = ResNetCustom()

        relu_modules = (m for m in net.modules() if isinstance(m, nn.ReLU))
        for relu in relu_modules:
            relu.inplace = False

        return net.to(self.device)

    def _build_lr_scheduler(self):
        """Create the learning-rate scheduler for the post-warmup phase.

        The schedule spans ``cfg.epochs - cfg.warmup_epochs`` epochs.  With
        ``cfg.lr_schedule == "cosine"`` it is ``CosineAnnealingLR`` decaying to 1e-4;
        otherwise ``MultiStepLR`` (gamma 0.1) with milestones at
        ``int(m * main_epochs)`` for each fraction ``m`` in ``cfg.lr_milestones``.
        """
        main_epochs = self.cfg.epochs - self.cfg.warmup_epochs

        if self.cfg.lr_schedule == "cosine":
            # CosineAnnealingLR evaluates ``... % (2 * T_max)`` on every step after the
            # first, so T_max == 0 (no post-warmup epochs) would raise ZeroDivisionError.
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=max(1, main_epochs), eta_min=1e-4
            )

        milestones = [int(m * main_epochs) for m in self.cfg.lr_milestones]
        return torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=milestones, gamma=0.1
        )

    def _ema_update(self):
        """Move ``self.model_ema`` one EMA step towards ``self.model``.

        Every EMA parameter becomes ``decay * ema + (1 - decay) * live`` with
        ``decay = cfg.ema_decay``.  Buffers (e.g. BatchNorm running statistics) are
        copied verbatim, matched by name, so the EMA model is evaluated with current
        normalisation statistics instead of the ones it was created with.
        No-op when ``self.model_ema`` is None.
        """
        if self.model_ema is None:
            return

        decay = self.cfg.ema_decay
        with torch.no_grad():
            for param, param_ema in zip(self.model.parameters(), self.model_ema.parameters()):
                param_ema.data.mul_(decay).add_(param.data, alpha=1 - decay)

            # Parameters alone leave BN running stats frozen at their initial values.
            ema_buffers = dict(self.model_ema.named_buffers())
            for name, buf in self.model.named_buffers():
                ema_buf = ema_buffers.get(name)
                if ema_buf is not None:
                    ema_buf.copy_(buf)

    def train_autoencoder(self, dataloader, epochs=20):
        """Train ``self.autoencoder`` to reconstruct clean inputs from corrupted ones.

        When ``cfg.use_denoising`` is set, each batch is corrupted with one randomly
        chosen noise type:

        * Gaussian noise scaled by ``cfg.noise_level`` (clamped to [0, 1]);
        * salt-and-pepper style: ~5% of entries replaced by uniform random values;
        * a 70/30 blend of the input with an FGSM example at ``cfg.epsilon / 4``.
          For the first 3 epochs, clamped Gaussian noise is used instead.

        The loss is MSE to the clean input plus ``0.001 * mean(||z||_2)`` on the
        latent code.  Training uses Adam (lr 1e-3) with gradient clipping at 1.0 and
        ReduceLROnPlateau (patience 3, factor 0.5) on the epoch loss.  It stops early
        after 5 epochs without improvement.

        Args:
            dataloader: iterable of ``(x, y)`` batches.
            epochs: maximum number of epochs (default 20).
        """
        logger.info("Training denoising autoencoder...")
        self.autoencoder.train()

        ae_optimizer = torch.optim.Adam(self.autoencoder.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            ae_optimizer, patience=3, factor=0.5
        )

        best_loss = float('inf')
        patience_counter = 0

        for epoch in range(epochs):
            total_loss = 0.0
            n_batches = 0

            for x, y in dataloader:
                x = x.to(self.device)

                if self.cfg.use_denoising:
                    # Multiple noise types
                    noise_type = np.random.choice(['gaussian', 'salt_pepper', 'fgsm'])

                    if noise_type == 'gaussian':
                        noise = torch.randn_like(x) * self.cfg.noise_level
                        x_noisy = torch.clamp(x + noise, 0, 1)
                    elif noise_type == 'salt_pepper':
                        mask = torch.rand_like(x) < 0.05
                        x_noisy = x.clone()
                        x_noisy[mask] = torch.rand_like(x[mask])
                    elif epoch > 2:  # fgsm, only after a few epochs
                        x_fgsm = Attacks.fgsm(
                            self.model, x, y.to(self.device),
                            self.cfg.epsilon / 4,
                            self.mean, self.std,
                            use_adv_bn=False
                        )
                        # detach(): the AE loss must never backprop into the
                        # classifier, even if the attack returns an attached tensor.
                        x_noisy = (0.7 * x + 0.3 * x_fgsm).detach()
                    else:  # fgsm requested too early -> Gaussian fallback
                        x_noisy = x + torch.randn_like(x) * self.cfg.noise_level
                        x_noisy = torch.clamp(x_noisy, 0, 1)
                else:
                    x_noisy = x

                x_recon, z = self.autoencoder(x_noisy)

                # Reconstruction loss + regularization
                recon_loss = F.mse_loss(x_recon, x)
                reg_loss = 0.001 * z.norm(2, dim=1).mean()  # L2 regularization on latent
                loss = recon_loss + reg_loss

                ae_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.autoencoder.parameters(), 1.0)
                ae_optimizer.step()

                total_loss += loss.item()
                n_batches += 1

            if n_batches == 0:
                logger.warning("Autoencoder dataloader is empty; skipping AE training")
                break

            avg_loss = total_loss / n_batches
            scheduler.step(avg_loss)

            logger.info(f"AE Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, LR: {ae_optimizer.param_groups[0]['lr']:.6f}")

            # Early stopping
            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= 5:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break

    def finalize_transition_matrix(self):
        """Compute the reward-weighted transition matrix and log a short analysis.

        Logs the recording mode and start epoch, the total number of recorded
        transitions, every entry above 0.3 ("strong" transitions), and the
        normalised row entropy of the matrix.

        Returns:
            ``self.transition_scheduler.compute_T(normalize=True, use_rewards=True)``;
            presumably a NumPy array (it is fed to ``np.log``) — confirm against
            TransitionScheduler.
        """
        # Give the scheduler the latest rewards *before* computing T: compute_T is
        # asked to use rewards, so assigning them afterwards would let the matrix be
        # built from stale or missing rewards.
        if hasattr(self, 'transition_rewards'):
            self.transition_scheduler.transition_rewards = self.transition_rewards

        # Compute T with reward weighting
        T = self.transition_scheduler.compute_T(normalize=True, use_rewards=True)

        # Analyze the matrix properties
        logger.info(f"\nTransition Matrix Analysis (mode: {self.cfg.t_matrix_mode}):")
        logger.info(f"Recording started at epoch: "
                    f"{self.cfg.t_matrix_start_epoch if self.cfg.t_matrix_mode == 'late' else 0}")
        logger.info(f"Total transitions recorded: {int(self.transition_scheduler.counts.sum())}")

        # Check for strong patterns (high probability transitions)
        strong_transitions = [
            (i, j, T[i, j])
            for i in range(T.shape[0])
            for j in range(T.shape[1])
            if T[i, j] > 0.3  # Threshold for "strong" transition
        ]

        if strong_transitions:
            logger.info("Strong transitions found:")
            for i, j, prob in strong_transitions:
                logger.info(f"  Cluster {i} → {j}: {prob:.3f}")

        # Mean per-row Shannon entropy divided by log(K): 0 = deterministic rows,
        # 1 = uniform rows.  Dividing the raw sum by K*K instead caps the value at
        # log(K)/K < 0.37, so the 0.8 "random" threshold could never trigger.
        P = np.asarray(T, dtype=np.float64)
        if P.ndim == 2 and P.shape[0] > 0 and P.shape[1] > 1:
            row_entropy = -np.sum(P * np.log(P + 1e-8), axis=1)
            T_entropy = float(row_entropy.mean() / np.log(P.shape[1]))
        else:
            T_entropy = 0.0
        logger.info(f"Transition matrix entropy: {T_entropy:.3f} "
                    f"({'random' if T_entropy > 0.8 else 'structured'})")

        return T


    @staticmethod
    def _to_buf(values, device):
        return torch.tensor(values).view(1, 3, 1, 1).to(device)

    def discovery_pass(self):
        """Cluster training batches and build per-cluster groups of dataset indices.

        Steps:
          1. Embed every training image (``self.ae.encode`` when ``self.ae`` is set,
             otherwise the classifier's pooled penultimate features) and fit the
             geometry profiler's codebook / prototypes if they are not fitted yet.
          2. Re-iterate ``self.train_set`` unshuffled with ``cfg.batch_size`` and
             record, per batch, its true dataset indices, a geometry vector
             (``self.geom.batch_geometry``) and a stats vector
             (``self.stats.extract_batch_stats``).
          3. Assign one cluster label per batch according to
             ``cfg.cluster_feature_type``.
          4. Group dataset indices by label, (re)size the transition scheduler and
             save the labels to ``cfg.experiment_dir``.

        Returns:
            ``(groups, labels, T, Xs, Xg)``:

            * ``groups``: dict mapping cluster id to a list of dataset indices;
            * ``labels``: one cluster id per batch;
            * ``T``: transition matrix (identity for the single-view KMeans modes);
            * ``Xs`` / ``Xg``: stacked per-batch stats / geometry descriptors.
        """
        device = self.device

        # -------- 1) Embed all images (small batches for memory efficiency) --------
        discovery_batch_size = min(64, self.cfg.batch_size)  # Reduce batch size
        disc_loader = DataLoader(
            self.train_set,
            batch_size=discovery_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,  # Set to 0 for WSL stability
            pin_memory=False,  # Disable for WSL
        )

        # Set when train_set is a Subset: maps loader position -> true dataset index.
        subset_indices = getattr(self.train_set, "indices", None)

        # NOTE(review): this method looks for ``self.ae`` while others use
        # ``self.autoencoder``; without ``self.ae`` the classifier features are used.
        all_Z = []
        self.model.eval()

        with torch.no_grad():
            for x, _ in disc_loader:
                x = x.to(device)
                if getattr(self, "ae", None) is not None:
                    z = self.ae.encode(x)
                else:
                    z = self._get_model_features(x)
                all_Z.append(z.cpu())

                # Drop GPU references early and periodically release cached blocks.
                del x, z
                if len(all_Z) % 5 == 0 and torch.cuda.is_available():
                    torch.cuda.empty_cache()

        Z = torch.cat(all_Z, dim=0)  # every chunk was moved to CPU above
        del all_Z

        # -------- 2) Fit codebook (BoE) and prototypes (for OT signature) once --------
        if getattr(self.geom, "codebook", None) is None:
            self.geom.fit_codebook(Z)
        if getattr(self.geom, "prototypes", None) is None:
            self.geom.fit_prototypes(Z)
        del Z

        # -------- 3) Build per-batch descriptors + robust batch index lists --------
        Xs_list, Xg_list, batch_ids_list = [], [], []
        # Fresh, unshuffled loader with cfg.batch_size, so batch b covers loader
        # positions [b * batch_size, b * batch_size + len(batch)).
        disc_loader = DataLoader(
            self.train_set,
            batch_size=self.cfg.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
        )

        self.model.eval()
        for b_idx, (x, y) in enumerate(disc_loader):
            x = x.to(device)
            y = y.to(device)

            # ---- TRUE dataset indices for THIS batch ----
            start = b_idx * self.cfg.batch_size
            positions = range(start, start + x.size(0))
            if subset_indices is not None:
                batch_ids = [
                    subset_indices[p] if p < len(subset_indices) else subset_indices[0]  # Safe fallback
                    for p in positions
                ]
            else:
                batch_ids = list(positions)
            batch_ids_list.append(batch_ids)

            # ---- per-image embeddings for THIS batch ----
            with torch.no_grad(), UseAdvBN(self.model, False):
                if getattr(self, "ae", None) is not None:
                    z = self.ae.encode(x)  # (B, d)
                else:
                    z = self._get_model_features(x)

            # ---- geometry vector (G1–G6 implemented inside) ----
            # NOTE(review): other call sites also pass x_batch=/model=; confirm whether
            # the embedding-only call here is intentional.
            g_vec = self.geom.batch_geometry(z)  # (dg,)
            Xg_list.append(g_vec)

            # ---- stats vector (class-agnostic toggle lives inside StatsProfiler) ----
            s_vec = self.stats.extract_batch_stats(self.model, x, y, self.cfg.epsilon)
            Xs_list.append(s_vec)

        Xs = torch.stack(Xs_list, dim=0)  # (n_batches, ds)
        Xg = torch.stack(Xg_list, dim=0)  # (n_batches, dg)

        # -------- 4) Clustering to final per-batch labels + transition matrix T --------
        from sklearn.cluster import MiniBatchKMeans

        mvc = self.clusterer
        feature_type = self.cfg.cluster_feature_type

        if feature_type == "geom_only":
            # Validate geometry features
            if not torch.isfinite(Xg).all():
                logger.warning("Non-finite geom features detected, cleaning...")
                Xg = torch.nan_to_num(Xg, nan=0.0, posinf=1.0, neginf=-1.0)

            geom_std = Xg.std(0)
            if (geom_std < 1e-6).any():
                logger.warning(f"Low variance in {(geom_std < 1e-6).sum()} geom features - adding noise")
                Xg = Xg + torch.randn_like(Xg) * 1e-5

            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42, max_iter=20)
            labels = torch.tensor(kmeans.fit_predict(Xg.cpu().numpy())).long()
            T = torch.eye(self.cfg.n_clusters)

        elif feature_type == "stats_only":
            # Validate stats features
            if not torch.isfinite(Xs).all():
                logger.warning("Non-finite stats features detected, cleaning...")
                Xs = torch.nan_to_num(Xs, nan=0.0, posinf=1.0, neginf=-1.0)

            stats_std = Xs.std(0)
            if (stats_std < 1e-6).any():
                logger.warning(f"Low variance in {(stats_std < 1e-6).sum()} stats features - adding noise")
                Xs = Xs + torch.randn_like(Xs) * 1e-5

            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42, max_iter=20)
            labels = torch.tensor(kmeans.fit_predict(Xs.cpu().numpy())).long()
            T = torch.eye(self.cfg.n_clusters)

        elif feature_type == "simclr_only" and hasattr(self, 'simclr'):
            # Extract SimCLR features
            simclr_features = []
            for x, _ in disc_loader:
                x = x.to(device)
                z_simclr = self.simclr.extract_features(x)
                simclr_features.append(z_simclr.cpu())
            Z_simclr = torch.cat(simclr_features, dim=0)

            # Validate SimCLR features
            if not torch.isfinite(Z_simclr).all():
                logger.warning("Non-finite SimCLR features detected, cleaning...")
                Z_simclr = torch.nan_to_num(Z_simclr, nan=0.0, posinf=1.0, neginf=-1.0)

            simclr_std = Z_simclr.std(0)
            if (simclr_std < 1e-6).any():
                logger.warning(f"Low variance in {(simclr_std < 1e-6).sum()} SimCLR features - adding noise")
                Z_simclr = Z_simclr + torch.randn_like(Z_simclr) * 1e-5

            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42, max_iter=20)
            labels = torch.tensor(kmeans.fit_predict(Z_simclr.numpy())).long()
            T = torch.eye(self.cfg.n_clusters)

        elif feature_type == "confidence":
            confidence_features = self.extract_confidence_patterns(disc_loader)

            # Validate confidence features
            if not torch.isfinite(confidence_features).all():
                logger.warning("Non-finite confidence features detected, cleaning...")
                confidence_features = torch.nan_to_num(confidence_features, nan=0.0, posinf=1.0, neginf=-1.0)

            conf_std = confidence_features.std(0)
            if (conf_std < 1e-6).any():
                logger.warning(f"Low variance in {(conf_std < 1e-6).sum()} confidence features - adding noise")
                confidence_features = confidence_features + torch.randn_like(confidence_features) * 1e-5

            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42, max_iter=20)
            # .cpu(): .numpy() fails on CUDA tensors
            labels = torch.tensor(kmeans.fit_predict(confidence_features.cpu().numpy())).long()
            T = torch.eye(self.cfg.n_clusters)

        elif feature_type == "adv_dynamics":
            # Cluster by how far FGSM at increasing eps moves the logits
            eps_values = [self.cfg.epsilon * 0.25, self.cfg.epsilon * 0.5, self.cfg.epsilon * 0.75, self.cfg.epsilon]
            adv_features = []
            for x, y in disc_loader:
                x, y = x.to(device), y.to(device)
                patterns = []
                for eps in eps_values:
                    x_adv = Attacks.fgsm(self.model, x, y, eps, self.mean, self.std)
                    with torch.no_grad():
                        clean_logits = self.model(Attacks.normalize(x, self.mean, self.std))
                        adv_logits = self.model(Attacks.normalize(x_adv, self.mean, self.std))
                        shift = (adv_logits - clean_logits).norm(dim=1).mean()
                    patterns.append(shift.item())
                adv_features.append(torch.tensor(patterns))

            Z_adv = torch.stack(adv_features)
            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42)
            labels = torch.tensor(kmeans.fit_predict(Z_adv.cpu().numpy())).long()
            T = torch.eye(self.cfg.n_clusters)

        elif feature_type == "loss_landscape":
            # Cluster by the loss under small random input perturbations
            landscape_features = []
            for x, y in disc_loader:
                x, y = x.to(device), y.to(device)
                losses = []
                for _ in range(8):
                    noise = torch.randn_like(x) * 0.01
                    x_perturbed = (x + noise).clamp(0, 1)
                    with torch.no_grad():
                        loss = F.cross_entropy(self.model(Attacks.normalize(x_perturbed, self.mean, self.std)), y)
                    losses.append(loss.item())
                landscape_features.append(torch.tensor(losses))

            Z_landscape = torch.stack(landscape_features)
            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42)
            labels = torch.tensor(kmeans.fit_predict(Z_landscape.cpu().numpy())).long()
            T = torch.eye(self.cfg.n_clusters)

        elif feature_type == "grad_coherence":
            # Cluster by statistics of per-sample input-gradient norms
            coherence_features = []
            for x, y in disc_loader:
                x, y = x.to(device), y.to(device)
                x = x.requires_grad_(True)

                logits = self.model(Attacks.normalize(x, self.mean, self.std))
                loss = F.cross_entropy(logits, y, reduction='none')

                batch_grad = torch.autograd.grad(loss.sum(), x)[0]
                grad_norm = batch_grad.view(x.size(0), -1).norm(dim=1)

                coherence_features.append(torch.tensor([
                    grad_norm.mean().item(),
                    grad_norm.std().item(),
                    grad_norm.max().item(),
                    grad_norm.min().item()
                ]))

            Z_coherence = torch.stack(coherence_features)
            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42)
            labels = torch.tensor(kmeans.fit_predict(Z_coherence.cpu().numpy())).long()
            T = torch.eye(self.cfg.n_clusters)

        elif feature_type == "consistency":
            # Cluster by prediction stability under small Gaussian input noise
            consistency_features = []
            for x, _y in disc_loader:
                x = x.to(device)
                predictions = []
                for _ in range(5):
                    noise = torch.randn_like(x) * 0.01
                    x_noisy = (x + noise).clamp(0, 1)
                    with torch.no_grad():
                        pred = self.model(Attacks.normalize(x_noisy, self.mean, self.std)).argmax(1)
                    predictions.append(pred)
                predictions = torch.stack(predictions)  # (5, B)
                # Fraction of samples whose 5 noisy predictions all agree
                # (std() is not defined for integer tensors, so compare to draw 0).
                consistency = (predictions == predictions[0:1]).all(dim=0).float().mean()
                consistency_features.append([consistency.item()])

            Z_consistency = torch.tensor(consistency_features)  # (n_batches, 1)
            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42)
            labels = torch.tensor(kmeans.fit_predict(Z_consistency.numpy())).long()
            T = torch.eye(self.cfg.n_clusters)

        elif feature_type == "activations":
            # Cluster by batch-averaged layer-3 activations
            activation_features = []
            for x, _y in disc_loader:
                x = x.to(device)
                with torch.no_grad():
                    h = self.model.layer3(self.model.layer2(self.model.layer1(
                        self.model.relu(self.model.bn1(self.model.conv1(
                            Attacks.normalize(x, self.mean, self.std)))))))
                    # Global average pooling, then a batch-level mean: labels are
                    # assigned per batch, so we need exactly one row per batch.
                    h_batch = h.mean(dim=[2, 3]).mean(0)
                activation_features.append(h_batch.cpu())

            Z_act = torch.stack(activation_features)  # (n_batches, C)
            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42)
            labels = torch.tensor(kmeans.fit_predict(Z_act.numpy())).long()
            T = torch.eye(self.cfg.n_clusters)

        elif feature_type == "adaptive_comprehensive":
            groups, labels, T, feature_weights = self.comprehensive_discovery_with_learning()

            # Store for later use
            self.learned_feature_weights = feature_weights

            # Save to file for analysis
            weights_path = self.cfg.experiment_dir / f'feature_weights_epoch_{self.current_epoch}.json'
            with open(weights_path, 'w') as f:
                json.dump(feature_weights, f, indent=2)

            logger.info(f"Saved feature weights to {weights_path}")

        else:  # multi_view (default)
            # Validate both feature sets
            if not torch.isfinite(Xs).all() or not torch.isfinite(Xg).all():
                logger.warning("Non-finite features detected, replacing with zeros")
                Xs = torch.nan_to_num(Xs, nan=0.0, posinf=1.0, neginf=-1.0)
                Xg = torch.nan_to_num(Xg, nan=0.0, posinf=1.0, neginf=-1.0)

            if Xs.std(0).min() < 1e-6 or Xg.std(0).min() < 1e-6:
                logger.warning("Near-zero variance features detected")

            stats_std = Xs.std(0)
            geom_std = Xg.std(0)
            if (stats_std < 1e-6).any():
                logger.warning(f"Low variance in {(stats_std < 1e-6).sum()} stats features - adding noise")
                Xs = Xs + torch.randn_like(Xs) * 1e-5

            if (geom_std < 1e-6).any():
                logger.warning(f"Low variance in {(geom_std < 1e-6).sum()} geom features - adding noise")
                Xg = Xg + torch.randn_like(Xg) * 1e-5

            try:
                mvc.fit(Xs, Xg, batch_ids_list, epoch=self.current_epoch)
                labels = mvc.final_labels
                T = mvc.get_transition()
            except Exception as e:
                logger.error(f"Clustering failed: {type(e).__name__}: {str(e)}")

                # Initialize proper fallback state
                n_batches = len(batch_ids_list)
                labels = torch.randint(0, self.cfg.n_clusters, (n_batches,))
                mvc.final_labels = labels
                mvc.T = torch.eye(self.cfg.n_clusters).to(self.device)

                # Set minimal scalers to prevent downstream errors
                mvc.stats_mu = torch.zeros(Xs.shape[1]).to(self.device)
                mvc.stats_sigma = torch.ones(Xs.shape[1]).to(self.device)
                mvc.geom_mu = torch.zeros(Xg.shape[1]).to(self.device)
                mvc.geom_sigma = torch.ones(Xg.shape[1]).to(self.device)
                mvc.pair_to_final = {(i % self.cfg.K_stats, i % self.cfg.K_geom): i % self.cfg.n_clusters
                                    for i in range(max(self.cfg.n_clusters, self.cfg.K_stats, self.cfg.K_geom))}

                logger.info("Fallback clustering with random assignment")
                labels = mvc.final_labels
                T = mvc.get_transition()

        # -------- 5) Build causal groups: final_label → list of TRUE-index batches --------
        Kf = int(labels.max().item()) + 1 if labels.numel() > 0 else self.cfg.K_final
        Kf = max(Kf, self.cfg.n_clusters)  # Ensure we have at least n_clusters groups

        # Keep the transition scheduler sized to the number of final groups
        if hasattr(self, 'transition_scheduler'):
            if self.transition_scheduler.n_clusters != Kf:
                logger.info(f"Reinitializing transition scheduler: {self.transition_scheduler.n_clusters} -> {Kf} clusters")
                self.transition_scheduler = TransitionScheduler(Kf, self.cfg)
        else:
            self.transition_scheduler = TransitionScheduler(Kf, self.cfg)

        groups = {k: [] for k in range(Kf)}
        for b, lab in enumerate(labels.tolist()):
            if lab < Kf:  # Safety check
                groups[lab].extend(batch_ids_list[b])

        # (Optional) Log a quick summary for HITL/analysis
        sizes = {k: len(v) for k, v in groups.items()}
        print(f"[Discovery] Final groups={Kf} | sizes={sizes} | T row sums={T.sum(1).cpu().numpy().round(3)}")

        cluster_save_path = self.cfg.experiment_dir / f'cluster_labels_epoch_{self.current_epoch}.npy'
        np.save(cluster_save_path, labels.cpu().numpy())
        logger.info(f"Saved cluster labels to {cluster_save_path}")

        return groups, labels, T, Xs, Xg

    def comprehensive_discovery_with_learning(self):
        import numpy as np
        """Discovery that evaluates all feature types and learns optimal weighting"""
        device = self.device

        # Create discovery loader
        disc_loader = DataLoader(
            self.train_set,
            batch_size=self.cfg.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0
        )

        # Store all feature types separately
        feature_collections = {
            'stats': [],
            'geom': [],
            'confidence': [],
            'adv_dynamics': [],
            'grad_coherence': [],
            'activations': [],
            'consistency': [],
            'loss_landscape': []
        }

        batch_ids_list = []

        logger.info("Collecting all feature types...")

        for batch_idx, (x, y) in enumerate(disc_loader):
            if batch_idx % 20 == 0:
                logger.info(f"Processing batch {batch_idx}/{len(disc_loader)}")

            x, y = x.to(device), y.to(device)

            # Collect batch IDs
            batch_ids = list(range(batch_idx * x.size(0), batch_idx * x.size(0) + x.size(0)))
            batch_ids_list.append(batch_ids)

            # 1. Stats features
            s_vec = self.stats.extract_batch_stats(self.model, x, y, self.cfg.epsilon)
            feature_collections['stats'].append(s_vec.cpu())

            # 2. Geometry features
            with torch.no_grad():
                z = self.autoencoder.encode(x)
            g_vec = self.geom.batch_geometry(z, x_batch=x, model=self.model)
            feature_collections['geom'].append(g_vec.cpu())

            # 3. Confidence patterns
            with torch.no_grad():
                logits = self.model(Attacks.normalize(x, self.mean, self.std))
                probs = F.softmax(logits, dim=1)
                conf_features = torch.tensor([
                    probs.max(dim=1)[0].mean().item(),
                    -(probs * probs.log()).sum(dim=1).mean().item(),
                    probs.std(dim=1).mean().item()
                ])
            feature_collections['confidence'].append(conf_features)

            # 4. Adversarial dynamics
            eps_values = [self.cfg.epsilon * 0.25, self.cfg.epsilon * 0.5, self.cfg.epsilon * 0.75, self.cfg.epsilon]
            adv_patterns = []
            for eps in eps_values:
                x_adv = Attacks.fgsm(self.model, x, y, eps, self.mean, self.std)
                with torch.no_grad():
                    adv_logits = self.model(Attacks.normalize(x_adv, self.mean, self.std))
                    shift = (adv_logits - logits).norm(dim=1).mean().item()
                    adv_patterns.append(shift)
            feature_collections['adv_dynamics'].append(torch.tensor(adv_patterns))

            # 5. Gradient coherence
            x_grad = x.clone().requires_grad_(True)
            loss = F.cross_entropy(self.model(Attacks.normalize(x_grad, self.mean, self.std)), y)
            grad = torch.autograd.grad(loss, x_grad)[0]
            grad_norm = grad.view(x.size(0), -1).norm(dim=1)
            grad_features = torch.tensor([
                grad_norm.mean().item(),
                grad_norm.std().item(),
                grad_norm.max().item() / (grad_norm.min().item() + 1e-8)
            ])
            feature_collections['grad_coherence'].append(grad_features)

            # 6. Activation patterns (from layer 3)
            with torch.no_grad():
                h = self.model.layer3(self.model.layer2(self.model.layer1(
                    self.model.relu(self.model.bn1(self.model.conv1(
                        Attacks.normalize(x, self.mean, self.std)))))))
                h_pooled = h.mean(dim=[2, 3])
                # Reduce dimension with PCA-like projection
                h_reduced = h_pooled.mean(0)  # Batch-level summary
            feature_collections['activations'].append(h_reduced.cpu())

            # 7. Prediction consistency
            predictions = []
            for _ in range(5):
                noise = torch.randn_like(x) * 0.01
                x_noisy = (x + noise).clamp(0, 1)
                with torch.no_grad():
                    pred = self.model(Attacks.normalize(x_noisy, self.mean, self.std)).argmax(1)
                predictions.append(pred)

            # Calculate consistency differently
            predictions = torch.stack(predictions)  # Shape: (5, batch_size)
            # For each sample, count how many predictions agree with the mode
            mode_pred = torch.mode(predictions, dim=0)[0]  # Most common prediction
            agreements = (predictions == mode_pred.unsqueeze(0)).float().mean(0)  # How often each sample agrees
            consistency = agreements.mean().item()  # Average consistency across batch

            # Calculate entropy of predictions
            pred_counts = torch.zeros(x.size(0), 10, device=x.device)  # 10 classes for CIFAR10
            for i in range(5):
                pred_counts.scatter_add_(1, predictions[i].unsqueeze(1), torch.ones_like(predictions[i].unsqueeze(1), dtype=torch.float))
            pred_probs = pred_counts / 5.0
            entropy = -(pred_probs * (pred_probs + 1e-8).log()).sum(1).mean().item()

            feature_collections['consistency'].append(torch.tensor([consistency, entropy]))

            # 8. Loss landscape
            losses = []
            for _ in range(5):
                noise = torch.randn_like(x) * 0.01
                x_perturbed = (x + noise).clamp(0, 1)
                with torch.no_grad():
                    loss_val = F.cross_entropy(
                        self.model(Attacks.normalize(x_perturbed, self.mean, self.std)), y
                    ).item()
                losses.append(loss_val)
            landscape_features = torch.tensor([
                np.mean(losses),
                np.std(losses),
                max(losses) - min(losses)
            ])
            feature_collections['loss_landscape'].append(landscape_features)

        # Stack all features
        logger.info("Stacking and normalizing features...")
        stacked_features = {}
        for name, feat_list in feature_collections.items():
            stacked = torch.stack(feat_list)
            # Normalize each feature type
            stacked = (stacked - stacked.mean(0)) / (stacked.std(0) + 1e-8)
            stacked_features[name] = stacked

        logger.info("Learning optimal feature weights based on type discovery...")
        feature_scores = {}
        optimal_k_per_feature = {}

        def _evaluate_feature_for_learning_local(name, features, sample_size=500):
            """Local evaluation function if method doesn't exist"""
            # Simple variance-based scoring as fallback
            variance = features.var(0).mean().item()
            return variance, min(5, max(2, int(variance * 10)))

        # Use the class method if it exists, otherwise use local fallback
        evaluate_func = (self._evaluate_feature_for_learning
                        if hasattr(self, '_evaluate_feature_for_learning')
                        else _evaluate_feature_for_learning_local)

        # Create feature weighter with correct parameters
        feature_weighter = type('FeatureWeighter', (), {
            'evaluate_feature_for_learning': lambda self, name, features, sample_size=500:
                self.trainer._evaluate_feature_for_learning(name, features, sample_size),
            'trainer': self
        })()

        # Evaluate each feature type individually
        logger.info("Evaluating individual feature types...")
        cluster_results = {}

        for name, features in stacked_features.items():
            if torch.isnan(features).any() or features.std() < 1e-8:
                logger.warning(f"Skipping {name} due to invalid features")
                feature_scores[name] = 0.0
                continue

            logger.info(f"Evaluating {name} for learning potential...")
            score, optimal_k = evaluate_func(name, features)  # Use the function we defined
            feature_scores[name] = score
            optimal_k_per_feature[name] = optimal_k
            logger.info(f"  {name}: score={score:.3f}, optimal_k={optimal_k}")

        # Learn optimal feature weights through comprehensive testing
        logger.info("Learning optimal feature weights through comprehensive testing...")

        # Initialize weights
        best_weights = {name: 1.0 for name in feature_collections.keys()}
        best_score = -float('inf')
        best_features = None
        best_active = list(feature_collections.keys())

        # Grid search over weight combinations
        # Use gradient-free optimization (CMA-ES style) for continuous weights
        logger.info("Learning continuous feature weights using evolutionary optimization...")

        import numpy as np
        from scipy.optimize import differential_evolution

        feature_names = list(feature_collections.keys())
        n_features = len(feature_names)

        # Initialize best tracking
        best_weights = {name: 1.0 for name in feature_names}
        best_score = -float('inf')
        best_features = None
        best_active = feature_names
        best_labels = None

        def evaluate_weights(weight_array):
            """Objective function for weight optimization"""
            # Apply weights to features
            weighted_features = []
            active_features = []

            for i, (name, weight) in enumerate(zip(feature_names, weight_array)):
                if weight > 0.01:  # Threshold for considering active
                    weighted_features.append(stacked_features[name] * weight)
                    active_features.append(name)

            if not weighted_features:
                return -1000.0  # Penalty for all-zero weights

            # Combine weighted features
            combined = torch.cat([f.flatten(1) for f in weighted_features], dim=1)

            # Normalize
            combined = (combined - combined.mean(0)) / (combined.std(0) + 1e-8)

            # Quick clustering
            from sklearn.cluster import MiniBatchKMeans
            kmeans = MiniBatchKMeans(n_clusters=self.cfg.n_clusters, random_state=42, max_iter=20)
            test_labels = kmeans.fit_predict(combined.cpu().numpy())

            # Evaluate clustering quality
            score = self._evaluate_clustering_quality(test_labels, combined)

            # Store if best
            nonlocal best_score, best_weights, best_features, best_active, best_labels
            if score > best_score:
                best_score = score
                best_weights = {name: w for name, w in zip(feature_names, weight_array)}
                best_features = combined
                best_active = active_features
                best_labels = test_labels
                logger.info(f"  New best score: {score:.3f} with weights: {[f'{w:.2f}' for w in weight_array]}")

            return -score  # Minimize negative score

        # Use differential evolution for continuous optimization
        bounds = [(0.0, 1.0) for _ in range(n_features)]

        # Initial guess: equal weights
        x0 = np.array([1.0 / n_features] * n_features)

        logger.info(f"Optimizing {n_features} feature weights...")

        result = differential_evolution(
            evaluate_weights,
            bounds,
            maxiter=30,  # Limit iterations for speed
            popsize=10,  # Population size
            seed=42,
            workers=1,  # Use single worker for determinism
            updating='deferred',
            disp=True
        )

        # Extract final best weights
        optimal_weights = result.x
        evaluate_weights(optimal_weights)  # Ensure best is stored

        logger.info(f"Optimal weights found: {best_weights}")
        logger.info(f"Active features with non-zero weights: {best_active}")
        logger.info(f"Best combination score: {best_score:.3f}")


        # Use the best weighted combination for final clustering
        final_labels = torch.tensor(best_labels).long()
        final_k = len(np.unique(best_labels))  # MUST DEFINE final_k BEFORE using it!

        logger.info(f"Final clustering resulted in {final_k} clusters")

        # Build groups
        groups = {k: [] for k in range(final_k)}
        for b, lab in enumerate(final_labels.tolist()):
            groups[lab].extend(batch_ids_list[b])

        # Store the discovery mode and configuration
        self.discovery_feature_mode = 'adaptive_comprehensive'
        self.discovery_feature_dim = best_features.shape[1]

        # The ordering test logic should be DEFERRED but we need to handle both cases
        if self.current_epoch > 5:  # Only test orderings if we have a trained model
            logger.info("Testing cluster orderings with trained model...")
            ordering_tester = ClusterOrderingTester(self)
            try:
                best_ordering, all_scores = ordering_tester.find_best_orderings(
                    groups,
                    n_test=5,
                    test_epochs=2
                )
                logger.info(f"Ordering test completed, found best: {best_ordering[0] if isinstance(best_ordering, tuple) else best_ordering}")
            except Exception as e:
                logger.warning(f"Ordering test failed: {e}, using default")
                best_ordering = ("cyclical", list(range(final_k)) * 100)
                all_scores = {"cyclical": 0.0}
        else:
            # First discovery - can't test orderings without trained model
            logger.info("First discovery (epoch <= 5) - deferring ordering tests")
            best_ordering = ("cyclical_initial", list(range(final_k)) * 100)
            all_scores = {"cyclical_initial": 0.0}

        # Unpack and validate the ordering result
        if isinstance(best_ordering, tuple) and len(best_ordering) == 2:
            best_order_name, best_order_sequence = best_ordering
            if not isinstance(best_order_sequence, (list, tuple)):
                logger.error(f"best_order_sequence is {type(best_order_sequence)}, not a list!")
                best_order_sequence = list(range(final_k)) * 100
        elif isinstance(best_ordering, str):
            best_order_name = best_ordering
            best_order_sequence = list(range(final_k)) * 100
        else:
            logger.error(f"Unexpected best_ordering type: {type(best_ordering)}")
            best_order_name = "fallback"
            best_order_sequence = list(range(final_k)) * 100

        # Ensure it's a list and reasonable length
        best_order_sequence = list(best_order_sequence)[:1000]

        # Build T matrix based on best ordering (or initial if first discovery)
        logger.info("Building transition matrix from ordering...")
        T = np.zeros((final_k, final_k))
        if len(best_order_sequence) > 1:
            for i in range(len(best_order_sequence) - 1):
                from_c = best_order_sequence[i] % final_k  # Ensure valid index
                to_c = best_order_sequence[i + 1] % final_k
                T[from_c, to_c] += 1
        else:
            T = np.eye(final_k)  # Identity if no sequence

        # Normalize rows to get transition probabilities
        row_sums = T.sum(axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1  # Avoid division by zero
        T = T / row_sums
        T = torch.tensor(T, dtype=torch.float32)

        # Store everything - including all the learned weights!
        self.feature_weights = {
            'weights': best_weights,  # The actual learned weights for each feature
            'active_features': best_active,  # Features with non-zero weights
            'individual_scores': feature_scores,  # Individual feature scores from earlier
            'combination_score': best_score,  # Best combination score
            'best_ordering': best_order_name,
            'ordering_scores': all_scores,
            'best_order_sequence': best_order_sequence,  # ADD THIS - the actual sequence!
            'feature_combination_mode': 'adaptive_comprehensive',
            'n_combinations_tested': result.nit if 'result' in locals() else 30,
            'discovered_at_epoch': self.current_epoch
        }

        # Store for use during training
        self.discovered_best_order = best_order_sequence
        self.discovered_feature_weights = self.feature_weights

        # Log what we discovered
        logger.info(f"Discovery complete: {len(groups)} groups created")
        logger.info(f"Feature weight summary:")
        for feat_name, weight in best_weights.items():
            if weight > 0:
                logger.info(f"  {feat_name}: {weight:.2f}")

        return groups, final_labels, T, self.feature_weights

    def _evaluate_clustering_quality(self, labels, features):
        """Score a clustering by combining several quality metrics.

        Five metric scores are computed once for the given clustering:
        silhouette, Calinski-Harabasz (divided by 1000 and capped at 1.0),
        centroid diversity, spread of input-gradient norms across clusters and
        spread of FGSM robustness across clusters. A small grid of metric
        weightings is then searched and the highest weighted sum is kept; the
        winning weighting is stored on ``self.discovered_metric_weights``.

        Args:
            labels: Cluster label per row of ``features`` (array-like of ints).
            features: Torch tensor of clustered rows (indexed row-wise with
                ``labels``).

        Returns:
            The best weighted score. After epoch 0 a one-off correlation bonus
            from ``self._check_metric_correlation`` is added when it is finite.
        """
        from sklearn.metrics import silhouette_score, calinski_harabasz_score

        # detach(): a tensor that requires grad cannot be converted to numpy,
        # which previously made both sklearn metrics silently fall back to 0.
        features_np = features.detach().cpu().numpy()
        has_multiple_clusters = len(np.unique(labels)) > 1

        # Metrics that cannot be computed (single cluster, degenerate input)
        # contribute 0.0.
        metric_scores = {'silhouette': 0.0, 'calinski': 0.0}

        if has_multiple_clusters:
            # 1. Silhouette score (cluster separation)
            try:
                metric_scores['silhouette'] = silhouette_score(
                    features_np, labels,
                    sample_size=min(1000, len(labels))
                )
            except Exception as exc:
                logger.debug(f"silhouette_score failed, using 0.0: {exc}")

            # 2. Calinski-Harabasz (between/within cluster variance ratio),
            #    squashed into [0, 1] with a fixed 1000 scale.
            try:
                ch = calinski_harabasz_score(features_np, labels)
                metric_scores['calinski'] = min(1.0, ch / 1000.0)
            except Exception as exc:
                logger.debug(f"calinski_harabasz_score failed, using 0.0: {exc}")

        # 3. Diversity score
        metric_scores['diversity'] = self._compute_cluster_diversity_score(labels, features)

        # 4. Learning gradient - can the model learn from this clustering?
        metric_scores['learning_gradient'] = self._test_learning_gradient(labels, features)

        # 5. Robustness variance - do clusters have different robustness?
        metric_scores['robustness_variance'] = self._test_robustness_variance(labels)

        if not hasattr(self, 'metric_weights_cache'):
            self.metric_weights_cache = {}

        # Grid-search the metric weights; the remaining mass is split evenly
        # between the two model-based metrics (learning_gradient, robustness).
        best_final_score = -float('inf')
        best_metric_weights = None
        grid = [0.0, 0.2, 0.4, 0.6]
        for w1, w2, w3 in itertools.product(grid, repeat=3):
            if w1 + w2 + w3 > 1.0:
                continue
            w4 = max(0, 1.0 - w1 - w2 - w3) * 0.5
            w5 = max(0, 1.0 - w1 - w2 - w3) * 0.5

            score = (w1 * metric_scores['silhouette'] +
                     w2 * metric_scores['calinski'] +
                     w3 * metric_scores['diversity'] +
                     w4 * metric_scores['learning_gradient'] +
                     w5 * metric_scores['robustness_variance'])

            if score > best_final_score:
                best_final_score = score
                best_metric_weights = [w1, w2, w3, w4, w5]

        # The correlation check records one history entry per call, so it must
        # run once per clustering. Calling it for every grid weighting filled
        # the history window with identical accuracies, spearmanr returned
        # NaN, and the NaN bonus poisoned every score.
        if self.current_epoch > 0 and best_metric_weights is not None:
            correlation_bonus = self._check_metric_correlation(best_final_score, labels)
            if np.isfinite(correlation_bonus):
                best_final_score += correlation_bonus

        if best_metric_weights:
            self.discovered_metric_weights = {
                'silhouette': best_metric_weights[0],
                'calinski': best_metric_weights[1],
                'diversity': best_metric_weights[2],
                'learning_gradient': best_metric_weights[3],
                'robustness_variance': best_metric_weights[4]
            }

            # Log the learned weighting only once per trainer.
            if not hasattr(self, '_logged_metric_weights'):
                logger.info(f"Discovered metric weights: {self.discovered_metric_weights}")
                self._logged_metric_weights = True

        return best_final_score

    def _test_learning_gradient(self, labels, features, max_clusters=5, samples_per_cluster=10):
        """Measure how much input-gradient magnitude differs between clusters.

        For up to ``max_clusters`` distinct labels, the first
        ``samples_per_cluster`` positions carrying that label are looked up in
        ``self.train_set``. For each sample, the L2 norm of the gradient of the
        cross-entropy loss with respect to the input is computed. These norms
        are then averaged per cluster.

        NOTE(review): the positions come from ``labels``. When the clustering
        was done over batch-level feature rows, these are batch positions, not
        ``train_set`` sample indices. Confirm with the caller.

        Args:
            labels: Cluster label per clustered row (numpy array).
            features: Unused; kept for interface compatibility.
            max_clusters: Maximum number of clusters to probe.
            samples_per_cluster: Maximum samples probed per cluster.

        Returns:
            Standard deviation of the per-cluster mean gradient norms, or 0.0
            when fewer than two clusters produced measurements.
        """
        unique_labels = np.unique(labels)
        if len(unique_labels) <= 1:
            return 0.0

        gradients = []
        for label in unique_labels[:max_clusters]:
            indices = np.where(labels == label)[0][:samples_per_cluster]

            cluster_grads = []
            for idx in indices:
                x, y = self.train_set[idx]
                x = x.unsqueeze(0).to(self.device).requires_grad_(True)
                y = torch.tensor([y]).to(self.device)

                loss = F.cross_entropy(
                    self.model(Attacks.normalize(x, self.mean, self.std)), y
                )
                # Only one gradient is taken per graph, so it need not be retained.
                grad = torch.autograd.grad(loss, x)[0]
                cluster_grads.append(grad.norm().item())

            if cluster_grads:
                gradients.append(np.mean(cluster_grads))

        # Higher spread = more diverse learning signals across clusters.
        return np.std(gradients) if len(gradients) > 1 else 0.0

    def _test_robustness_variance(self, labels, max_clusters=5, samples_per_cluster=10):
        """Measure how much FGSM robustness differs between clusters.

        For up to ``max_clusters`` distinct labels, the first
        ``samples_per_cluster`` positions carrying that label are looked up in
        ``self.train_set``. Each sample is attacked with FGSM at half of
        ``self.cfg.epsilon``. The fraction still classified correctly is then
        averaged per cluster.

        NOTE(review): positions in ``labels`` are used as ``train_set``
        indices. If the clustering was over batch-level rows, this mapping
        needs confirming.

        Args:
            labels: Cluster label per clustered row (numpy array).
            max_clusters: Maximum number of clusters to probe.
            samples_per_cluster: Maximum samples probed per cluster.

        Returns:
            Standard deviation of per-cluster robust accuracy, or 0.0 when
            fewer than two clusters produced measurements.
        """
        unique_labels = np.unique(labels)
        if len(unique_labels) <= 1:
            return 0.0

        cluster_robustness = []
        for label in unique_labels[:max_clusters]:
            indices = np.where(labels == label)[0][:samples_per_cluster]

            robust_scores = []
            for idx in indices:
                x, y = self.train_set[idx]
                x = x.unsqueeze(0).to(self.device)
                y = torch.tensor([y]).to(self.device)

                # FGSM is gradient-based: run it with autograd enabled rather
                # than relying on Attacks.fgsm to re-enable grad internally.
                # Only the evaluation forward pass is wrapped in no_grad.
                x_adv = Attacks.fgsm(self.model, x, y, self.cfg.epsilon * 0.5,
                                     self.mean, self.std)
                with torch.no_grad():
                    logits_adv = self.model(Attacks.normalize(x_adv, self.mean, self.std))
                robust_scores.append((logits_adv.argmax(1) == y).float().mean().item())

            if robust_scores:
                cluster_robustness.append(np.mean(robust_scores))

        return np.std(cluster_robustness) if len(cluster_robustness) > 1 else 0.0

    def _check_metric_correlation(self, score, labels):
        """Check if this metric score correlates with actual learning improvement"""
        if not hasattr(self, 'metric_score_history'):
            self.metric_score_history = []
            self.robust_acc_history = []
            return 0.0

        # Store current score and performance
        self.metric_score_history.append(score)
        if hasattr(self, 'last_robust_acc'):
            self.robust_acc_history.append(self.last_robust_acc)

        # Compute correlation if we have enough history
        if len(self.metric_score_history) > 5 and len(self.robust_acc_history) > 5:
            from scipy.stats import spearmanr
            correlation, _ = spearmanr(
                self.metric_score_history[-5:],
                self.robust_acc_history[-5:]
            )
            return correlation * 0.1  # Small bonus for correlation

        return 0.0

    def _compute_cluster_diversity_score(self, labels, features):
        """Check if clusters have meaningfully different characteristics"""
        unique_labels = np.unique(labels)
        if len(unique_labels) <= 1:
            return 0.0

        # Compute cluster centroids
        centroids = []
        for label in unique_labels:
            mask = (labels == label)
            if np.sum(mask) > 0:
                cluster_features = features[mask].mean(0)
                centroids.append(cluster_features)

        if len(centroids) < 2:
            return 0.0

        # Compute pairwise distances between centroids
        distances = []
        for i in range(len(centroids)):
            for j in range(i+1, len(centroids)):
                dist = torch.norm(centroids[i] - centroids[j]).item()
                distances.append(dist)

        # Higher average distance = more diverse clusters
        avg_distance = np.mean(distances) if distances else 0.0

        # Normalize to [0, 1] range
        return min(1.0, avg_distance / 10.0)

    def _compute_cluster_robustness_variance(self, labels, max_clusters=5, samples_per_cluster=20):
        """Measure how much FGSM robustness differs between clusters.

        For up to ``max_clusters`` distinct labels, the first
        ``samples_per_cluster`` positions carrying that label are looked up in
        ``self.train_set``. Each sample is attacked with FGSM at half of
        ``self.cfg.epsilon``, and whether it stays correctly classified is
        recorded.

        NOTE(review): positions in ``labels`` are used as ``train_set``
        indices. Confirm this matches how the clustering rows were built.

        Args:
            labels: Cluster label per clustered row (numpy array).
            max_clusters: Maximum number of clusters to probe.
            samples_per_cluster: Maximum samples probed per cluster.

        Returns:
            Standard deviation of per-cluster robust accuracy, or 0.0 when
            fewer than two clusters were measured.
        """
        cluster_robust = defaultdict(list)

        unique_labels = np.unique(labels)
        for label in unique_labels[:max_clusters]:
            indices = np.where(labels == label)[0][:samples_per_cluster]

            for idx in indices:
                x, y = self.train_set[idx]
                x = x.unsqueeze(0).to(self.device)
                y = torch.tensor([y]).to(self.device)

                # FGSM is gradient-based: run it with autograd enabled rather
                # than relying on Attacks.fgsm to re-enable grad internally.
                # The unused clean forward pass was dropped.
                x_adv = Attacks.fgsm(self.model, x, y, self.cfg.epsilon * 0.5,
                                     self.mean, self.std)
                with torch.no_grad():
                    logits_adv = self.model(Attacks.normalize(x_adv, self.mean, self.std))
                cluster_robust[label].append((logits_adv.argmax(1) == y).item())

        cluster_means = [np.mean(cluster_robust[lab]) for lab in unique_labels if lab in cluster_robust]

        if len(cluster_means) > 1:
            return np.std(cluster_means)
        return 0.0

    def _profile_discovered_type(self, sample_indices, max_samples=50):
        """Build a robustness/gradient profile for one discovered type.

        For up to ``max_samples`` indices into ``self.train_set``, this
        records three things:
        - whether the model stays correct under FGSM at 0.25 * epsilon;
        - whether it stays correct under FGSM at the full epsilon;
        - the L2 norm of the input gradient of the clean cross-entropy loss.

        Args:
            sample_indices: Indices into ``self.train_set`` for this type.
            max_samples: Maximum number of samples probed (default 50).

        Returns:
            A dict with keys ``description``, ``robust_to_weak``,
            ``robust_to_strong``, ``avg_gradient``, ``learning_rate``
            (``1 / (1 + avg_gradient)``, a speed-of-learning proxy) and
            ``sample_indices``.
        """
        robust_weak = []
        robust_strong = []
        grad_norms = []

        for idx in sample_indices[:max_samples]:
            x, y = self.train_set[idx]
            x = x.unsqueeze(0).to(self.device)
            y = torch.tensor([y]).to(self.device)

            # FGSM is gradient-based, so it runs with autograd enabled; only
            # the evaluation forwards are wrapped in no_grad.
            # Weak attack
            x_weak = Attacks.fgsm(self.model, x, y, self.cfg.epsilon * 0.25,
                                  self.mean, self.std)
            with torch.no_grad():
                pred_weak = self.model(Attacks.normalize(x_weak, self.mean, self.std))
            robust_weak.append((pred_weak.argmax(1) == y).item())

            # Strong attack
            x_strong = Attacks.fgsm(self.model, x, y, self.cfg.epsilon,
                                    self.mean, self.std)
            with torch.no_grad():
                pred_strong = self.model(Attacks.normalize(x_strong, self.mean, self.std))
            robust_strong.append((pred_strong.argmax(1) == y).item())

            # Gradient characteristics, on a fresh leaf so ``x`` is untouched.
            x_grad = x.detach().clone().requires_grad_(True)
            loss = F.cross_entropy(self.model(Attacks.normalize(x_grad, self.mean, self.std)), y)
            grad = torch.autograd.grad(loss, x_grad)[0]
            grad_norms.append(grad.norm().item())

        avg_robust_weak = np.mean(robust_weak)
        avg_robust_strong = np.mean(robust_strong)
        avg_grad = np.mean(grad_norms)

        # NOTE(review): the two gradient branches compare avg_grad with
        # percentiles of this same type's grad_norms. They therefore only fire
        # for strongly skewed samples. Comparing against gradients pooled
        # across all types may have been intended; confirm.
        if avg_robust_strong > 0.7:
            description = "Naturally robust"
        elif avg_robust_weak > 0.7 and avg_robust_strong < 0.3:
            description = "Vulnerable to strong attacks"
        elif avg_grad < np.percentile(grad_norms, 25):
            description = "Low gradient (easy to learn)"
        elif avg_grad > np.percentile(grad_norms, 75):
            description = "High gradient (hard to learn)"
        else:
            description = "Mixed difficulty"

        return {
            'description': description,
            'robust_to_weak': avg_robust_weak,
            'robust_to_strong': avg_robust_strong,
            'avg_gradient': avg_grad,
            'learning_rate': 1.0 / (1.0 + avg_grad),  # Proxy for how fast it learns
            'sample_indices': sample_indices
        }

    def _build_type_transition_matrix(self, type_profiles, n_types):
        """Build transition matrix based on pedagogical relationships between types"""

        T = np.zeros((n_types, n_types))

        # Order types by learning difficulty
        difficulty_order = sorted(type_profiles.items(),
                                key=lambda x: x[1]['robust_to_strong'])

        # Create curriculum flow
        for i, (type_i, profile_i) in enumerate(difficulty_order):
            for j, (type_j, profile_j) in enumerate(difficulty_order):
                if i == j:
                    T[type_i, type_j] = 0.1  # Small self-loop
                elif j == i + 1:
                    # Strong transition to next difficulty
                    T[type_i, type_j] = 0.5
                elif profile_j['learning_rate'] > profile_i['learning_rate']:
                    # Moderate transition to easier-to-learn types
                    T[type_i, type_j] = 0.2
                else:
                    T[type_i, type_j] = 0.05  # Weak transition

        # Normalize rows
        T = T / (T.sum(axis=1, keepdims=True) + 1e-8)

        return torch.tensor(T, dtype=torch.float32)

    def test_and_update_ordering(self, epoch, ordering_test_epoch=10):
        """Re-run the cluster-ordering search once the model has trained.

        This only acts when ``epoch == ordering_test_epoch`` and
        ``self.groups`` is non-empty. On success it updates
        ``self.discovered_best_order`` with the winning sequence. If
        ``self.feature_weights`` exists, its ordering name, scores, sequence
        and discovery epoch are also updated. If the search raises or returns
        an unexpected shape, the current ordering is kept and a warning is
        logged.

        Args:
            epoch: Current epoch number.
            ordering_test_epoch: The epoch at which the test runs (default 10).
        """
        if epoch != ordering_test_epoch:
            return

        if not getattr(self, 'groups', None):
            return

        logger.info("="*50)
        logger.info(f"TESTING ORDERINGS AT EPOCH {epoch}")
        logger.info("="*50)

        # NOW the model has learned something, so we can test orderings
        ordering_tester = ClusterOrderingTester(self)
        try:
            best_ordering, all_scores = ordering_tester.find_best_orderings(
                self.groups,
                n_test=5,
                test_epochs=2
            )
        except Exception as e:
            logger.warning(f"Ordering test failed: {e}, keeping current ordering")
            return

        if not (isinstance(best_ordering, tuple) and len(best_ordering) == 2):
            logger.warning(f"Unexpected ordering result type {type(best_ordering)}, keeping current ordering")
            return

        best_name, best_sequence = best_ordering
        self.discovered_best_order = best_sequence

        # Keep feature_weights in sync, including the sequence itself, so it
        # does not keep pointing at the previous ordering.
        if hasattr(self, 'feature_weights'):
            self.feature_weights['best_ordering'] = best_name
            self.feature_weights['ordering_scores'] = all_scores
            self.feature_weights['best_order_sequence'] = best_sequence
            self.feature_weights['ordering_discovered_at_epoch'] = epoch

        best_score = all_scores.get(best_name, 0) if isinstance(all_scores, dict) else 0
        logger.info(f"Updated to use {best_name} ordering with score {best_score:.3f}")

    def _save_clustering_state(self, epoch):
        """Pickle the current clustering state for reproducibility.

        The state is written to ``clustering_epoch_{epoch}.pkl`` under
        ``self.cfg.experiment_dir``. It contains:
        - the geometry profiler's codebook and prototypes;
        - the clusterer's stats/geom normalisation statistics;
        - the clusterer's ``pair_to_final`` mapping and transition matrix;
        - a plain-dict copy of ``self.groups``;
        - the epoch.

        NOTE: the output is a pickle. Only load such files from trusted
        locations.

        Args:
            epoch: Epoch number, used in the file name and stored in the state.
        """
        state = {
            'codebook': self.geom_profiler.codebook,
            'prototypes': self.geom_profiler.prototypes,
            'stats_mu': self.clusterer.stats_mu,
            'stats_sigma': self.clusterer.stats_sigma,
            'geom_mu': self.clusterer.geom_mu,
            'geom_sigma': self.clusterer.geom_sigma,
            'pair_to_final': self.clusterer.pair_to_final,
            'T': self.clusterer.T,
            'groups': dict(self.groups),
            'epoch': epoch
        }

        # Path() accepts either a str or a Path, so string configs also work.
        path = Path(self.cfg.experiment_dir) / f'clustering_epoch_{epoch}.pkl'
        with open(path, 'wb') as f:
            pickle.dump(state, f)
        logger.info(f"Saved clustering state to {path}")

    def calibrate_bn(self, model, dataloader, steps=16):
        """Refresh BatchNorm running statistics on clean and adversarial inputs.

        For up to ``steps`` batches, ``model`` (in train mode) is run without
        gradients on the clean batch through the clean BN path, then on an
        FGSM adversarial batch through the adversarial BN path (``UseAdvBN``).
        The model's previous train/eval mode is restored afterwards.

        Args:
            model: Network whose BN statistics are refreshed.
            dataloader: Iterable yielding ``(x, y)`` batches.
            steps: Maximum number of batches to use.
        """
        was_training = model.training
        model.train()
        try:
            # islice stops after ``steps`` batches without loading an extra one.
            for x, y in itertools.islice(dataloader, steps):
                x, y = x.to(self.device), y.to(self.device)

                # Clean forward (same mean/std as every other call site).
                with torch.no_grad(), UseAdvBN(model, False):
                    _ = model(Attacks.normalize(x, self.mean, self.std))

                # Adversarial forward; FGSM needs autograd so it runs outside no_grad.
                x_adv = Attacks.fgsm(
                    model, x, y,
                    self.cfg.epsilon,
                    self.mean,
                    self.std,
                    use_adv_bn=True
                )

                with torch.no_grad(), UseAdvBN(model, True):
                    _ = model(Attacks.normalize(x_adv, self.mean, self.std))
        finally:
            # Don't leave a caller's eval-mode model in train mode.
            model.train(was_training)

    def run_probe(self, dataloader, groups):
        """Estimate per-group PGD robustness with a cheap probe.

        The model is put in eval mode and left there. For each non-empty group,
        up to ``3 * batch_size`` of its indices are sampled without
        replacement from ``dataloader.dataset``. Those samples are attacked in
        batches with early-stopping PGD (3 steps, adversarial BN path), and
        correct predictions and PGD step usage are tallied. Each group's
        robust accuracy and PGD count are then reported to
        ``self.scheduler.update_probe_result``.

        Args:
            dataloader: Loader whose ``dataset`` is indexed with group indices.
            groups: Mapping from group id to a list of dataset indices.

        Returns:
            A defaultdict mapping group id to a dict with ``robust``,
            ``total`` and ``pgd`` (steps used times batch size) counts.
        """
        self.model.eval()
        group_results = defaultdict(lambda: {'robust': 0, 'total': 0, 'pgd': 0})

        for group_id, indices in groups.items():
            if not indices:
                continue

            # Sample up to 3 batches from this group
            n_samples = min(3 * self.cfg.batch_size, len(indices))
            sampled_indices = np.random.choice(indices, n_samples, replace=False)

            for i in range(0, n_samples, self.cfg.batch_size):
                batch_indices = sampled_indices[i:i + self.cfg.batch_size]

                # Load each item once: indexing the dataset twice (for x and
                # for y) doubled I/O and ran any random transforms twice.
                samples = [dataloader.dataset[idx] for idx in batch_indices]
                x = torch.stack([sample[0] for sample in samples])
                y = torch.tensor([sample[1] for sample in samples])
                x, y = x.to(self.device), y.to(self.device)

                # Quick PGD-3 probe
                x_adv, steps_used = Attacks.pgd(
                    self.model, x, y,
                    self.cfg.epsilon, self.cfg.pgd_step_size, 3,
                    self.mean, self.std,
                    use_adv_bn=True, early_stop=True, min_early_stop_steps=2
                )

                with torch.no_grad(), UseAdvBN(self.model, True):
                    logits = self.model(Attacks.normalize(x_adv, self.mean, self.std))
                    robust = (logits.argmax(1) == y).sum().item()

                group_results[group_id]['robust'] += robust
                group_results[group_id]['total'] += len(batch_indices)
                group_results[group_id]['pgd'] += steps_used * len(batch_indices)

        # Update scheduler with probe results
        for group_id, results in group_results.items():
            if results['total'] > 0:
                robust_acc = results['robust'] / results['total']
                self.scheduler.update_probe_result(group_id, robust_acc, results['pgd'])

        return group_results


    def train_step_with_metrics(self, x, y, cluster_id, step):
        """Run one adversarial training step and return detailed batch metrics.

        Attack choice:
        - ``Attacks.pgd_adaptive`` when ``cfg.adaptive_pgd`` is set, a cluster
          id is given and ``cfg.mode != "teacher"``;
        - ``Attacks.pgd`` otherwise.

        After the attack, the PGD budget policy (``cfg.pgd_budget_mode``) is
        applied. The loss is clean cross-entropy plus
        ``beta * KL(p_clean || p_adv)``, with ``beta`` taken from
        ``self._get_trades_beta()``.

        Args:
            x: input batch; it goes through ``Attacks.normalize`` before the
                model, so it is expected un-normalised.
            y: integer class labels for ``x``.
            cluster_id: cluster of this batch, or ``None``.
            step: unused; kept for interface compatibility.

        Returns:
            dict of scalar metrics. If the budget is exhausted in ``"stop"``
            mode, it returns early (no optimiser step) with
            ``'budget_exhausted': True`` and zeroed metrics.
        """
        batch_start = time.time()
        is_teacher = (self.cfg.mode == "teacher")
        use_adaptive = self.cfg.adaptive_pgd and cluster_id is not None and not is_teacher
        # Step budget this batch's attack runs with. It is captured up front
        # because the "throttle" policy below may lower cfg.pgd_steps for later
        # batches; steps_saved must be measured against this attack's budget.
        max_steps = self.cfg.pgd_steps

        if use_adaptive:
            x_adv, steps_used = Attacks.pgd_adaptive(
                self.model, x, y,
                self.cfg.epsilon,
                self.cfg.pgd_step_size,
                max_steps,
                self.mean, self.std,
                random_start=True,
                use_adv_bn=True,
                cluster_difficulties=self.cluster_difficulties,
                cluster_id=cluster_id,
                difficulty_profiles=getattr(self, 'difficulty_profiles', None),
                current_epoch=self.current_epoch,
                max_epochs=self.cfg.epochs,
                min_early_stop_steps=self.cfg.min_early_stop_steps
            )
        else:
            x_adv, steps_used = Attacks.pgd(
                self.model, x, y,
                self.cfg.epsilon,
                self.cfg.pgd_step_size,
                max_steps,
                self.mean, self.std,
                random_start=True,
                use_adv_bn=True,
                early_stop=self.cfg.early_stop_pgd,
                min_early_stop_steps=self.cfg.min_early_stop_steps
            )

        batch_size = x.size(0)
        self.pgd_calls_train += steps_used * batch_size
        self.pgd_calls_epoch += steps_used * batch_size

        if self.cfg.pgd_budget_total and self.cfg.pgd_budget_mode != "none":
            if self.pgd_calls_train >= self.cfg.pgd_budget_total:
                if self.cfg.pgd_budget_mode == "stop":
                    logger.warning(f"PGD budget exhausted ({self.pgd_calls_train}/{self.cfg.pgd_budget_total})")
                    return {'budget_exhausted': True, 'loss': 0, 'acc_clean': 0, 'acc_robust': 0,
                            'steps_used': 0, 'steps_saved': 0}
                elif self.cfg.pgd_budget_mode == "throttle":
                    # Shrink the per-batch step budget for subsequent batches.
                    budget_remaining = self.cfg.pgd_budget_total - self.pgd_calls_train
                    budget_per_epoch = budget_remaining / max(1, self.cfg.epochs - self.current_epoch)
                    if hasattr(self, 'batches_per_epoch') and self.batches_per_epoch:
                        self.cfg.pgd_steps = max(1, min(self.cfg.pgd_steps,
                                                        int(budget_per_epoch / self.batches_per_epoch / self.cfg.batch_size)))

        steps_saved = (max_steps - steps_used) * batch_size

        self.optimizer.zero_grad()

        # Forward passes: clean through the clean BN branch, adversarial
        # through the adversarial BN branch.
        with UseAdvBN(self.model, False):
            logits_clean = self.model(Attacks.normalize(x, self.mean, self.std))

        with UseAdvBN(self.model, True):
            logits_adv = self.model(Attacks.normalize(x_adv, self.mean, self.std))

        # Losses
        loss_clean = F.cross_entropy(logits_clean, y)
        beta = self._get_trades_beta()

        if beta > 0:
            p_clean = F.softmax(logits_clean, dim=1)
            loss_kl = F.kl_div(
                F.log_softmax(logits_adv, dim=1),
                p_clean.detach(),
                reduction='batchmean'
            )
            loss = loss_clean + beta * loss_kl
        else:
            loss_kl = torch.tensor(0.0)
            loss = loss_clean

        loss.backward()

        # Global L2 gradient norm before clipping (for logging).
        total_norm = 0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().norm(2)
                total_norm += param_norm.item() ** 2
        grad_norm = total_norm ** 0.5

        if self.cfg.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)

        self.optimizer.step()
        self.global_step += 1
        self._ema_update()

        with torch.no_grad():
            pred_clean = logits_clean.argmax(1)
            pred_adv = logits_adv.argmax(1)

            acc_clean = (pred_clean == y).float().mean()
            acc_robust = (pred_adv == y).float().mean()

            # Mean per-sample L2 norm of the perturbation.
            delta = (x_adv - x).view(batch_size, -1)
            pert_norm = delta.norm(p=2, dim=1).mean()

            # Mean margin: true-class logit minus best wrong-class logit.
            correct_logits = logits_adv.gather(1, y.unsqueeze(1))
            wrong_logits = logits_adv.clone()
            wrong_logits.scatter_(1, y.unsqueeze(1), float('-inf'))
            wrong_max, _ = wrong_logits.max(dim=1)
            # squeeze(1) (not squeeze()) keeps shape (B,) even when B == 1.
            margin = (correct_logits.squeeze(1) - wrong_max).mean()

            attack_success = (pred_adv != y).float().mean()

        self.batch_metrics['losses'].append(loss.item())
        self.batch_metrics['clean_acc'].append(acc_clean.item())
        self.batch_metrics['robust_acc'].append(acc_robust.item())
        self.batch_metrics['pgd_steps'].append(steps_used)

        if cluster_id is not None:
            self.pgd_calls_per_cluster[cluster_id] += steps_used * batch_size
            self.robust_acc_per_cluster[cluster_id].append(acc_robust.item())

        if self.global_step % self.cfg.log_interval == 0:
            self.time_series_log.append({
                'step': self.global_step,
                'epoch': self.current_epoch,
                'wall_time': time.time() - self.training_start_wall_time,
                'clean_acc': acc_clean.item(),
                'robust_acc': acc_robust.item(),
                'pgd_calls_cumulative': self.pgd_calls_train,
                'loss': loss.item(),
            })

        return {
            'loss': loss.item(),
            'loss_clean': loss_clean.item(),
            'loss_kl': loss_kl.item() if isinstance(loss_kl, torch.Tensor) else loss_kl,
            'acc_clean': acc_clean.item(),
            'acc_robust': acc_robust.item(),
            'steps_used': steps_used,
            'steps_saved': steps_saved,
            'grad_norm': grad_norm,
            'pert_norm': pert_norm.item(),
            'margin': margin.item(),
            'attack_success': attack_success.item(),
            'batch_time': time.time() - batch_start
        }


    def train_step(self, x, y, cluster_id):
        """Perform one adversarial optimisation step (student or teacher).

        Attack by mode:
        - Student mode (``self.is_student_mode`` truthy) always uses
          ``Attacks.pgd_adaptive``. Cluster difficulties are optionally
          multiplied by ``self.student_difficulty_scale`` and capped at 0.95.
        - Otherwise ``Attacks.pgd`` runs the full ``cfg.pgd_steps`` without
          early stopping.

        A missing or out-of-range ``cluster_id`` is repaired first: student
        mode uses ``_predict_cluster_robust``, teacher mode uses ``0``.

        Returns:
            dict with ``loss``, ``acc_clean``, ``acc_robust``, ``steps_used``,
            ``cluster_id`` and ``is_adaptive``.
        """
        cfg = self.cfg

        # Repair the cluster id before choosing the attack.
        if cluster_id is None or cluster_id >= cfg.n_clusters:
            if getattr(self, 'is_student_mode', False):
                cluster_id = self._predict_cluster_robust(x, y)
                logger.debug(f"Student: assigned robust cluster {cluster_id}")
            else:
                cluster_id = 0  # teacher-mode fallback

        is_student = getattr(self, 'is_student_mode', False)
        use_adaptive = is_student  # adaptive PGD is tied to student mode

        if use_adaptive and cluster_id is not None:
            if is_student and hasattr(self, 'student_difficulty_scale'):
                scale = self.student_difficulty_scale
                difficulties = {
                    cid: min(0.95, diff * scale)
                    for cid, diff in self.cluster_difficulties.items()
                }
            else:
                difficulties = self.cluster_difficulties

            x_adv, steps_used = Attacks.pgd_adaptive(
                self.model, x, y,
                cfg.epsilon, cfg.pgd_step_size, cfg.pgd_steps,
                self.mean, self.std,
                random_start=True,
                use_adv_bn=True,
                cluster_difficulties=difficulties,
                cluster_id=cluster_id,
                difficulty_profiles=getattr(self, 'difficulty_profiles', None),
                current_epoch=self.current_epoch,
                max_epochs=cfg.epochs,
                min_early_stop_steps=cfg.min_early_stop_steps,
                is_student_mode=is_student,
            )
        else:
            # Plain PGD; early stopping only when running as a student.
            min_steps = cfg.min_early_stop_steps if is_student else cfg.pgd_steps
            x_adv, steps_used = Attacks.pgd(
                self.model, x, y,
                cfg.epsilon, cfg.pgd_step_size, cfg.pgd_steps,
                self.mean, self.std,
                random_start=True,
                use_adv_bn=True,
                early_stop=is_student,
                min_early_stop_steps=min_steps,
            )

        # Record per-cluster student effort for online adaptation.
        if is_student and hasattr(self, 'student_cluster_performance'):
            record = {'steps_used': steps_used, 'epoch': self.current_epoch}
            self.student_cluster_performance[cluster_id].append(record)

        self.optimizer.zero_grad()

        with UseAdvBN(self.model, False):
            logits_clean = self.model(Attacks.normalize(x, self.mean, self.std))
        with UseAdvBN(self.model, True):
            logits_adv = self.model(Attacks.normalize(x_adv, self.mean, self.std))

        loss_clean = F.cross_entropy(logits_clean, y)
        beta = self._get_trades_beta()

        loss = loss_clean
        if beta > 0:
            clean_target = F.softmax(logits_clean, dim=1).detach()
            kl_term = F.kl_div(
                F.log_softmax(logits_adv, dim=1),
                clean_target,
                reduction='batchmean'
            )
            loss = loss_clean + beta * kl_term

        loss.backward()

        if cfg.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), cfg.grad_clip_norm)

        self.optimizer.step()
        self._ema_update()

        with torch.no_grad():
            clean_hits = (logits_clean.argmax(1) == y).float().mean().item()
            robust_hits = (logits_adv.argmax(1) == y).float().mean().item()

        return {
            'loss': loss.item(),
            'acc_clean': clean_hits,
            'acc_robust': robust_hits,
            'steps_used': steps_used,
            'cluster_id': cluster_id,
            'is_adaptive': use_adaptive
        }




    def _ensure_group_buckets(self, gid):
        """Ensure group-specific metric containers exist"""
        d = self.detailed_metrics
        if 'group_loss' not in d:
            d['group_loss'] = defaultdict(list)
        if 'group_robust_acc' not in d:
            d['group_robust_acc'] = defaultdict(list)
        if 'group_rewards' not in d:
            d['group_rewards'] = defaultdict(list)

    def get_uniform_order(self, groups):
        """Return a round-robin schedule of group ids, one entry per batch.

        The schedule has ``ceil(total_samples / cfg.batch_size)`` entries,
        where ``total_samples`` is the summed ``len`` of all groups. Group ids
        are cycled in sorted order. An empty ``groups`` yields ``[]``.
        """
        if not groups:
            return []

        total = sum(map(len, groups.values()))
        batch_count = math.ceil(total / self.cfg.batch_size)
        ids = sorted(groups)
        return [ids[k % len(ids)] for k in range(batch_count)]

    def train_student_adaptive(self, train_loader, val_loader, test_loader):
        """Train the student with cluster-ordered batches and adaptive PGD.

        Uses the teacher's autoencoder weights and difficulty profiles; no AE
        training happens here. Each epoch:
        1. builds a cluster schedule with ``_build_student_order``;
        2. optionally removes ``cfg.excluded_cluster`` from the groups;
        3. iterates a ``GroupBatchSampler``-backed loader, running
           ``_train_step_student`` on every batch.

        Args:
            train_loader: only its ``.dataset`` is used, to build the
                per-epoch ordered loader.
            val_loader: evaluated via ``self.evaluate`` every
                ``cfg.eval_interval`` epochs.
            test_loader: currently unused; kept for interface compatibility.

        Returns:
            list of per-epoch summary dicts (``self.student_epoch_metrics``).

        Raises:
            ValueError: if student mode is not enabled or no teacher
                difficulty profiles are loaded.
        """
        logger.info("Student mode: Using teacher's autoencoder weights")

        if not getattr(self, 'is_student_mode', False):
            raise ValueError("Student mode not properly initialized")
        if not getattr(self, 'difficulty_profiles', None):
            raise ValueError("No difficulty profiles loaded from teacher")

        self.train_loader = train_loader
        self.train_set = train_loader.dataset
        self.val_loader = val_loader

        self.pgd_calls_train = 0
        self.pgd_calls_eval = 0
        self.pgd_calls_epoch = 0
        self.training_start_time = time.time()
        self.training_start_wall_time = time.time()

        self.student_epoch_metrics = []

        for epoch in range(self.cfg.epochs):
            self.current_epoch = epoch
            epoch_start = time.time()

            logger.info(f"\nStudent Epoch {epoch+1}/{self.cfg.epochs}")

            if epoch > 0:
                self._update_student_performance()

            # Build the cluster schedule once per epoch.
            order = self._build_student_order(epoch)
            logger.info(f"Original order (first 20): {order[:20]}")
            logger.info(f"DEBUG: self.groups keys = {list(self.groups.keys())}")

            # Configs without an 'excluded_cluster' field mean "no exclusion";
            # the value is used consistently for every exclusion check below.
            excluded = getattr(self.cfg, 'excluded_cluster', None)

            groups = self.groups
            if excluded is not None:
                logger.info(f"Attempting to exclude cluster {excluded}, groups keys: {list(groups.keys())}")
                if excluded in groups:
                    excluded_size = len(groups[excluded])
                    groups = {k: v for k, v in groups.items() if k != excluded}
                    logger.info(f"Excluded cluster {excluded} ({excluded_size} samples) from training")

                    # Safety net: remap any remaining excluded entries in the
                    # order to a random active cluster.
                    valid_clusters = list(groups.keys())
                    order = [np.random.choice(valid_clusters) if c == excluded else c
                             for c in order]

            train_loader_ordered = DataLoader(
                self.train_set,
                batch_sampler=GroupBatchSampler(
                    groups, order, self.cfg.batch_size,
                    self.train_set, drop_small=False
                ),
                num_workers=0,
                pin_memory=False
            )

            logger.info(f"Created DataLoader with {len(train_loader_ordered)} batches")
            logger.info("Starting training loop...")

            # Diagnostics to verify sample/batch counting under exclusion.
            if excluded is not None:
                remaining_samples = sum(len(g) for g in groups.values())
                # The excluded id may not exist in self.groups. Check first
                # to avoid a KeyError, or an insertion into a defaultdict.
                excluded_samples = len(self.groups[excluded]) if excluded in self.groups else 0
                logger.info("=== DIAGNOSTIC: Cluster Exclusion Analysis ===")
                logger.info(f"Excluded cluster {excluded} had {excluded_samples} samples")
                logger.info(f"Remaining training samples: {remaining_samples}")
                logger.info(f"Expected batches (math): {math.ceil(remaining_samples / self.cfg.batch_size)}")
                logger.info(f"Actual batches in DataLoader: {len(train_loader_ordered)}")
                logger.info(f"Order length: {len(order)}")
                logger.info(f"Active clusters: {list(groups.keys())}")

            cluster_usage_tracker = defaultdict(int)
            epoch_metrics = defaultdict(list)
            adaptive_count = 0
            total_steps = 0

            # pgd_calls_epoch is a per-epoch counter (incremented by
            # _train_step_student). Reset it so the end-of-epoch PGD report
            # covers this epoch only, not all epochs so far.
            self.pgd_calls_epoch = 0

            self.model.train()

            for step, (x, y) in enumerate(train_loader_ordered):
                x, y = x.to(self.device), y.to(self.device)

                # Cluster from the schedule; fall back to robust prediction.
                cluster_id = order[step] if step < len(order) else None
                cluster_usage_tracker[cluster_id] += 1
                if cluster_id is None or cluster_id >= self.cfg.n_clusters:
                    cluster_id = self._predict_cluster_robust(x, y)

                metrics = self._train_step_student(x, y, cluster_id)

                if metrics.get('is_adaptive', False):
                    adaptive_count += 1
                total_steps += 1

                for k, v in metrics.items():
                    if isinstance(v, (int, float)):
                        epoch_metrics[k].append(v)

                if step % 50 == 0:
                    steps_used = metrics.get('steps_used', self.cfg.pgd_steps)
                    robust_acc = metrics.get('acc_robust', 0)
                    logger.info(
                        f"[{epoch+1}/{self.cfg.epochs}][{step}/{len(train_loader_ordered)}] "
                        f"Robust: {robust_acc:.3f} | "
                        f"Steps: {steps_used}/{self.cfg.pgd_steps} | "
                        f"Cluster: {cluster_id} | "
                        f"Adaptive: {adaptive_count}/{total_steps}"
                    )

            # Epoch summary
            epoch_time = time.time() - epoch_start
            adaptive_ratio = adaptive_count / max(1, total_steps)

            if excluded is not None:
                logger.info(f"=== EPOCH {epoch+1} CLUSTER USAGE ===")
                logger.info(f"Cluster usage counts: {dict(cluster_usage_tracker)}")
                total_batches = sum(cluster_usage_tracker.values())
                logger.info(f"Total batches processed: {total_batches}")

                actual_pgd_calls_epoch = self.pgd_calls_epoch
                expected_pgd_calls = total_batches * self.cfg.batch_size * self.cfg.pgd_steps
                logger.info(f"Actual PGD calls this epoch: {actual_pgd_calls_epoch}")
                logger.info(f"Expected PGD calls (if no early stop): {expected_pgd_calls}")
                logger.info(f"PGD efficiency: {actual_pgd_calls_epoch / max(1, expected_pgd_calls):.2%}")

            avg_metrics = {
                k: np.mean(v)
                for k, v in epoch_metrics.items()
                if v and isinstance(v[0], (int, float))
            }

            logger.info(f"\nStudent Epoch {epoch+1} Summary:")
            logger.info(f"  Robust Acc: {avg_metrics.get('acc_robust', 0):.4f}")
            logger.info(f"  Avg Steps: {avg_metrics.get('steps_used', 0):.1f}")
            logger.info(f"  Adaptive Ratio: {adaptive_ratio:.2f}")
            logger.info(f"  Epoch Time: {epoch_time:.1f}s")

            self.student_epoch_metrics.append({
                'epoch': epoch + 1,
                'adaptive_ratio': adaptive_ratio,
                'avg_robust': avg_metrics.get('acc_robust', 0),
                'avg_steps': avg_metrics.get('steps_used', 0),
                'epoch_time': epoch_time
            })

            if (epoch + 1) % self.cfg.eval_interval == 0:
                val_metrics = self.evaluate(val_loader)
                logger.info(f"  Val Robust: {val_metrics['robust']:.4f}")

            # Step the LR schedule exactly once per post-warmup epoch. It was
            # previously stepped twice, which decayed LR twice as fast.
            if epoch >= self.cfg.warmup_epochs and hasattr(self, 'lr_scheduler'):
                self.lr_scheduler.step()

            # Test and update ordering after sufficient training.
            if epoch == 10 and self.cfg.cluster_feature_type == "adaptive_comprehensive":
                self.test_and_update_ordering(epoch)

            # Two-phase teacher logic; only runs if cfg.mode == "teacher".
            if self.cfg.mode == "teacher" and epoch == 20:
                self.finalize_clusters()
                self.cfg.discovery_interval = 999  # Stop re-discovering
                logger.info("Switching to exploitation phase with finalized clusters")

            # Free cached GPU memory at the end of each epoch.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        logger.info("Student training completed")
        return self.student_epoch_metrics

    def _train_step_student(self, x, y, cluster_id):
        """One student optimisation step with transition-aware adaptive PGD.

        When ``self.T`` (cluster transition matrix) is available:
        - The PGD early-stop floor depends on ``T[prev_cluster, cluster_id]``:
          strength > 0.3 gives 2 steps, > 0.15 gives 3, < 0.05 gives 1, and
          anything else gives 3.
        - Per-cluster difficulties are re-derived from
          ``self.difficulty_profiles`` and the row/column sums of ``T``.

        Without ``T``, ``self.cluster_difficulties`` is used, or 0.5 for every
        cluster if that attribute is missing.

        The loss is clean cross-entropy plus ``beta * KL`` between the
        adversarial log-probabilities and the detached clean probabilities.

        Args:
            x: input batch (normalised inside via ``Attacks.normalize``).
            y: class labels.
            cluster_id: cluster of this batch. ``None`` or an id
                ``>= cfg.n_clusters`` is replaced by
                ``self._predict_cluster_robust(x, y)``.

        Returns:
            dict with ``loss``, ``acc_clean``, ``acc_robust``, ``steps_used``,
            ``cluster_id`` and ``is_adaptive`` (always ``True``).
        """
        if cluster_id is None or cluster_id >= self.cfg.n_clusters:
            cluster_id = self._predict_cluster_robust(x, y)

        use_adaptive = True  # the student path always uses adaptive PGD
        adaptive_min_steps = self.cfg.min_early_stop_steps
        T = getattr(self, 'T', None)

        if T is not None and cluster_id < len(T):
            prev = getattr(self, 'prev_cluster', None)
            # prev_cluster may come from a step where T was unavailable or
            # had a different size; only index T when it is in range.
            if prev is not None and prev < len(T):
                transition_strength = T[prev, cluster_id].item()

                # Stronger discovered transitions allow earlier PGD stopping.
                if transition_strength > 0.3:
                    adaptive_min_steps = 2
                    logger.debug(f"Strong transition {prev}->{cluster_id} (strength={transition_strength:.3f}), min_steps=2")
                elif transition_strength > 0.15:
                    adaptive_min_steps = 3
                    logger.debug(f"Moderate transition {prev}->{cluster_id} (strength={transition_strength:.3f}), min_steps=3")
                elif transition_strength < 0.05:
                    adaptive_min_steps = 1
                    logger.debug(f"Weak transition {prev}->{cluster_id} (strength={transition_strength:.3f}), min_steps=1")
                else:
                    adaptive_min_steps = 3

            # Reinterpret teacher difficulties by each cluster's role in T:
            # strong out-flow (row sum) vs strong in-flow (column sum).
            conceptual_context = {}
            for cid in range(len(T)):
                if cid not in self.difficulty_profiles:
                    conceptual_context[cid] = 0.5
                    continue
                original_diff = self.difficulty_profiles[cid].get('overall_difficulty', 0.5)
                if T[cid].sum().item() > 0.3:
                    conceptual_context[cid] = 0.6 + original_diff * 0.2
                elif T[:, cid].sum().item() > 0.3:
                    conceptual_context[cid] = 0.5 + original_diff * 0.5
                else:
                    conceptual_context[cid] = original_diff
        else:
            # No usable T matrix: fall back to stored or neutral difficulties.
            conceptual_context = (
                self.cluster_difficulties if hasattr(self, 'cluster_difficulties')
                else {i: 0.5 for i in range(self.cfg.n_clusters)}
            )

        x_adv, steps_used = Attacks.pgd_adaptive(
            self.model, x, y,
            self.cfg.epsilon,
            self.cfg.pgd_step_size,
            self.cfg.pgd_steps,
            self.mean, self.std,
            random_start=True,
            use_adv_bn=True,
            cluster_difficulties=conceptual_context,
            cluster_id=cluster_id,
            difficulty_profiles=self.difficulty_profiles,
            current_epoch=self.current_epoch,
            max_epochs=self.cfg.epochs,
            min_early_stop_steps=adaptive_min_steps,
            is_student_mode=True
        )

        if not hasattr(self, 'pgd_calls_train'):
            self.pgd_calls_train = 0
        if not hasattr(self, 'pgd_calls_epoch'):
            self.pgd_calls_epoch = 0
        self.pgd_calls_train += steps_used * x.size(0)
        self.pgd_calls_epoch += steps_used * x.size(0)

        # Remember this cluster so the next step can look up the transition.
        self.prev_cluster = cluster_id

        self.student_cluster_performance[cluster_id].append({
            'steps_used': steps_used,
            'epoch': self.current_epoch
        })

        self.optimizer.zero_grad()

        with UseAdvBN(self.model, False):
            logits_clean = self.model(Attacks.normalize(x, self.mean, self.std))

        with UseAdvBN(self.model, True):
            logits_adv = self.model(Attacks.normalize(x_adv, self.mean, self.std))

        loss_clean = F.cross_entropy(logits_clean, y)
        beta = self._get_trades_beta()

        if beta > 0:
            p_clean = F.softmax(logits_clean, dim=1).detach()
            loss_kl = F.kl_div(
                F.log_softmax(logits_adv, dim=1),
                p_clean,
                reduction='batchmean'
            )
            loss = loss_clean + beta * loss_kl
        else:
            loss = loss_clean

        loss.backward()

        if self.cfg.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)

        self.optimizer.step()
        self._ema_update()

        with torch.no_grad():
            acc_clean = (logits_clean.argmax(1) == y).float().mean().item()
            acc_robust = (logits_adv.argmax(1) == y).float().mean().item()

            # Attach robust accuracy to the record appended above.
            if len(self.student_cluster_performance[cluster_id]) > 0:
                self.student_cluster_performance[cluster_id][-1]['acc_robust'] = acc_robust

        return {
            'loss': loss.item(),
            'acc_clean': acc_clean,
            'acc_robust': acc_robust,
            'steps_used': steps_used,
            'cluster_id': cluster_id,
            'is_adaptive': use_adaptive
        }


    def _build_student_order(self, epoch):
        """Build the per-batch cluster schedule for one student epoch.

        The schedule has one cluster id per batch,
        ``ceil(total_samples / cfg.batch_size)`` entries. ``cfg.order_mode``
        selects the strategy:

        * ``"single_cluster"`` (needs ``cfg.single_cluster_id``): every batch
          uses that cluster.
        * ``"custom"`` (needs a non-empty ``cfg.custom_order`` string of
          comma-separated ints): that sequence is repeated.
        * ``"uniform"``: each batch draws a group id uniformly at random.
        * ``"sequential"``: group ids are cycled in sorted order.
        * ``"cycle_per_epoch"``: one group id per epoch, rotating.
        * ``"loat"``: a random walk over ``self.T``. It starts at the row with
          the largest sum and self-loops with a probability derived from T's
          diagonal, at most 10 in a row. Falls back to uniform without ``T``.
        * ``"ucb"`` and unknown modes: uniform (with a log message).

        Finally, if ``cfg.excluded_cluster`` is set, every occurrence of it is
        replaced by a random remaining group id.

        Args:
            epoch: zero-based epoch index (used by ``"cycle_per_epoch"``).

        Returns:
            list of cluster ids, one per batch (``[]`` if there are no groups).
        """
        n_batches = math.ceil(sum(len(v) for v in self.groups.values()) / self.cfg.batch_size)

        logger.info(f"DEBUG: order_mode = '{self.cfg.order_mode}', type = {type(self.cfg.order_mode)}")

        # Draw from the real group keys rather than range(len(groups)), so
        # non-contiguous ids stay valid. When the keys are exactly 0..n-1 this
        # is draw-for-draw identical to using positions directly.
        cluster_ids = sorted(self.groups)
        n_clusters = len(cluster_ids)
        if n_clusters == 0:
            logger.warning("No cluster groups available; returning an empty order")
            return []

        def uniform_order():
            # One uniformly random group id per batch.
            return [cluster_ids[np.random.randint(n_clusters)] for _ in range(n_batches)]

        mode = self.cfg.order_mode

        if hasattr(self.cfg, 'single_cluster_id') and mode == "single_cluster":
            cluster_id = self.cfg.single_cluster_id
            logger.info(f"Student forcing single cluster {cluster_id} for all {n_batches} batches")
            order = [cluster_id] * n_batches

        elif mode == "custom" and getattr(self.cfg, 'custom_order', None):
            # Tolerate whitespace and empty tokens, e.g. "0, 1,2,".
            order_sequence = [int(tok) for tok in self.cfg.custom_order.split(',') if tok.strip()]
            if order_sequence:
                order = [order_sequence[i % len(order_sequence)] for i in range(n_batches)]
                logger.info(f"Using custom order: {order_sequence} repeated for {n_batches} batches")
            else:
                logger.warning(f"custom_order '{self.cfg.custom_order}' contains no cluster ids, using uniform")
                order = uniform_order()

        elif mode == "uniform":
            order = uniform_order()
            logger.info(f"True uniform random ordering across {n_clusters} clusters")

        elif mode == "sequential":
            order = [cluster_ids[i % n_clusters] for i in range(n_batches)]
            logger.info(f"Sequential ordering: cycling through clusters {cluster_ids}")

        elif mode == "cycle_per_epoch":
            cluster_for_epoch = cluster_ids[epoch % n_clusters]
            order = [cluster_for_epoch] * n_batches
            logger.info(f"Cycle per epoch: Epoch {epoch} uses only cluster {cluster_for_epoch}")

        elif mode == "loat":
            if getattr(self, 'T', None) is not None:
                T_np = self.T.cpu().numpy() if hasattr(self.T, 'cpu') else self.T

                # Self-loop probability: follow the diagonal when it dominates
                # the off-diagonal mass.
                diag = np.diag(T_np)
                diagonal_strength = diag.mean()
                n_off_diag = T_np.size - len(diag)
                # Guard the 1x1 case, which would otherwise compute 0 / 0.
                off_diagonal_strength = ((T_np.sum() - diag.sum()) / n_off_diag
                                         if n_off_diag > 0 else 0.0)
                if diagonal_strength > off_diagonal_strength:
                    self_loop_prob = min(0.85, diagonal_strength)
                else:
                    self_loop_prob = 0.3

                # Walk over T's own index range, so successors are always
                # valid rows of T.
                n_states = T_np.shape[0]
                current = int(np.argmax(T_np.sum(axis=1)))  # largest out-degree
                consecutive_count = 0
                max_consecutive = 10
                order = []

                for i in range(n_batches):
                    order.append(current)
                    consecutive_count += 1
                    if i == n_batches - 1:
                        break

                    if consecutive_count < max_consecutive and np.random.random() < self_loop_prob:
                        continue  # stay on the current cluster

                    # Move along T's off-diagonal row.
                    T_row = T_np[current].copy()
                    T_row[current] = 0
                    if T_row.sum() > 0:
                        current = np.random.choice(len(T_row), p=T_row / T_row.sum())
                    else:
                        available = [c for c in range(n_states) if c != current]
                        current = np.random.choice(available) if available else current
                    consecutive_count = 0

                unique_clusters = len(set(order))
                transitions = sum(1 for a, b in zip(order, order[1:]) if a != b)
                logger.info(f"LOAT ordering: {unique_clusters} unique clusters, "
                            f"{transitions} transitions, self_loop_prob={self_loop_prob:.3f}")
            else:
                logger.info("LOAT mode but no T-matrix, using uniform ordering")
                order = uniform_order()

        elif mode == "ucb":
            logger.info("UCB ordering not implemented in student mode, using uniform")
            order = uniform_order()

        else:
            logger.warning(f"Unknown order_mode '{self.cfg.order_mode}', using uniform")
            order = uniform_order()

        # Replace the excluded cluster (if any) with a random valid one.
        excluded = getattr(self.cfg, 'excluded_cluster', None)
        if excluded is not None:
            valid_clusters = [k for k in self.groups.keys() if k != excluded]
            if valid_clusters:
                order = [np.random.choice(valid_clusters) if c == excluded else c
                         for c in order]

        return order

    def _get_hardest_cluster(self):
        """Return the ID of the cluster with the highest ``overall_difficulty``.

        Reads ``self.difficulty_profiles``, a mapping of cluster ID -> profile
        dict. A profile without an ``'overall_difficulty'`` entry is scored as
        0.5, the neutral value ``train`` uses when it initialises profiles.

        Returns:
            The key of the hardest profile; on ties the first such key in
            iteration order wins. Returns ``0`` when no profiles are available
            (attribute missing, ``None``, or empty).
        """
        profiles = getattr(self, 'difficulty_profiles', None)
        if not profiles:
            # Nothing discovered yet: keep the historical default of cluster 0.
            return 0

        # max() over the keys guarantees the result is an actual cluster ID,
        # even when every difficulty is <= 0 (a 0-seeded scan would return 0).
        return max(
            profiles,
            key=lambda cid: profiles[cid].get('overall_difficulty', 0.5),
        )

    def _update_student_performance(self):
        """Adapt ``self.student_difficulty_scale`` from recent per-cluster efficiency.

        For every cluster in ``self.student_cluster_performance`` with at least
        5 recorded entries, the most recent 10 entries are summarised as::

            efficiency = mean(acc_robust) / max(1e-6, mean(steps_used) / 10)

        Entries without ``'steps_used'`` count as ``self.cfg.pgd_steps``.
        Entries without ``'acc_robust'`` are left out of the accuracy mean. A
        cluster with no ``'acc_robust'`` at all contributes nothing.

        The mean efficiency over contributing clusters is compared with a
        target of 0.35:

        * above 115% of target -> scale *= 1.05, capped at 1.4
        * below 85% of target  -> scale *= 0.95, floored at 0.9
        * otherwise the scale is left unchanged.

        Does nothing if ``student_cluster_performance`` is absent or no cluster
        has enough history. ``self.student_difficulty_scale`` is read without a
        default, so it must already be set when an adjustment is made.
        """
        if not hasattr(self, 'student_cluster_performance'):
            return

        min_history = 5
        window = 10
        target_efficiency = 0.35  # Target efficiency for student

        total_efficiency = []

        for performance_list in self.student_cluster_performance.values():
            if len(performance_list) < min_history:
                continue

            # Materialise first: deque-backed histories do not support slicing.
            recent_perf = list(performance_list)[-window:]

            robust_accs = [p['acc_robust'] for p in recent_perf if 'acc_robust' in p]
            steps_used = [p.get('steps_used', self.cfg.pgd_steps) for p in recent_perf]

            if robust_accs and steps_used:
                avg_robust = np.mean(robust_accs)
                avg_steps = np.mean(steps_used)
                # Robust accuracy per 10 PGD steps; 1e-6 guards against zero steps.
                efficiency = avg_robust / max(1e-6, avg_steps / 10)
                total_efficiency.append(efficiency)

        if not total_efficiency:
            return

        overall_efficiency = np.mean(total_efficiency)

        if overall_efficiency > target_efficiency * 1.15:
            # Too easy, increase difficulty
            self.student_difficulty_scale = min(1.4, self.student_difficulty_scale * 1.05)
            logger.debug(f"Increasing difficulty scale to {self.student_difficulty_scale:.3f}")
        elif overall_efficiency < target_efficiency * 0.85:
            # Too hard, decrease difficulty
            self.student_difficulty_scale = max(0.9, self.student_difficulty_scale * 0.95)
            logger.debug(f"Decreasing difficulty scale to {self.student_difficulty_scale:.3f}")



    def train(self, train_loader=None, val_loader=None, test_loader=None):
        """Main training loop with periodic validation"""
        # Set loaders
        if train_loader is not None:
            self.train_loader = train_loader
            self.train_set = train_loader.dataset
        if val_loader is not None:
            self.val_loader = val_loader

        # Ensure batch_metrics exists for logging
        if not hasattr(self, 'batch_metrics'):
            from collections import deque
            self.batch_metrics = {
                'losses': deque(maxlen=200),
                'clean_acc': deque(maxlen=200),
                'robust_acc': deque(maxlen=200),
                'pgd_steps': deque(maxlen=200)
            }

        # Initialize means/std if not already set
        if not hasattr(self, "mean") or not hasattr(self, "std"):
            self.mean = torch.tensor(CIFAR10_MEAN).view(1, 3, 1, 1).to(self.device)
            self.std = torch.tensor(CIFAR10_STD).view(1, 3, 1, 1).to(self.device)

        # EMA alias (for reference)
        # self.ema = getattr(self, "model_ema", None)

        # Guard against AE double-run
        ran_ae = False
        if getattr(self.cfg, "ae_train_epochs", 0) > 0 or getattr(self.cfg, "ae_epochs", 0) > 0:
            if hasattr(self, "train_ae"):
                self.train_ae()
                ran_ae = True
            elif hasattr(self, "_train_ae_warmup"):
                self._train_ae_warmup()
                ran_ae = True

        # Only run second AE path if first didn't run
        if self.cfg.ae_train_epochs > 0 and not self.cfg.use_recipe:
            self.train_autoencoder(self.train_loader)

        # Initialize tracking
        self.pgd_calls_train = 0
        self.pgd_calls_eval = 0
        self.training_start_time = time.time()
        self.training_start_wall_time = time.time()

        # Initialize transition scheduler
        if not hasattr(self, 'transition_scheduler'):
            self.transition_scheduler = TransitionScheduler(self.cfg.n_clusters, self.cfg)
            logger.info(f"Initialized transition scheduler for {self.cfg.n_clusters} clusters")


        for epoch in range(self.cfg.epochs):
            self.epoch_start_time = time.time()
            self.current_epoch = epoch

            # Control when T matrix recording is active (don't recreate scheduler)
            self.transition_scheduler.set_recording_mode(self.cfg.t_matrix_mode, epoch, self.cfg)

            if self.transition_scheduler.recording_enabled:
                logger.info(f"Epoch {epoch}: T matrix recording ENABLED (mode: {self.cfg.t_matrix_mode})")
            else:
                logger.debug(f"Epoch {epoch}: T matrix recording disabled (waiting for epoch {self.cfg.t_matrix_start_epoch})")

            if self.transition_scheduler.recording_enabled:
                logger.info(f"Epoch {epoch}: T matrix recording ENABLED (mode: {self.cfg.t_matrix_mode})")
            else:
                logger.debug(f"Epoch {epoch}: T matrix recording disabled (waiting for epoch {self.cfg.t_matrix_start_epoch})")



            epoch_metrics = defaultdict(list)

            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{self.cfg.epochs}")
            print(f"{'='*60}")

            # Discovery phase
            if epoch % self.cfg.discovery_interval == 0 and not self.cfg.use_recipe:
                discovery_start = time.time()

                try:
                    if self.cfg.mode == "teacher" and self.cfg.cluster_feature_type == "adaptive_comprehensive":
                        # Use the sophisticated multi-feature discovery
                        result = self.comprehensive_discovery_with_learning()

                        # Validate the return value
                        if result is None or len(result) != 4:
                            raise ValueError(f"comprehensive_discovery_with_learning returned invalid result: {result}")

                        groups, labels, T, feature_weights = result

                        # Store the learned weights for recipe
                        self.feature_weights = feature_weights
                        self.discovered_feature_mode = feature_weights.get('feature_combination_mode', 'multi_view')

                        # Extract additional metadata
                        Xs = torch.zeros(len(labels), 10)  # Placeholder since we don't need these separately
                        Xg = torch.zeros(len(labels), 10)

                    elif self.cfg.mode == "teacher" and hasattr(self, 'comprehensive_discovery_pass'):
                        groups, labels, T, Xs, Xg = self.comprehensive_discovery_pass()
                    else:
                        groups, labels, T, Xs, Xg = self.discovery_pass()

                    # CRITICAL: Validate groups before proceeding
                    if not groups or len(groups) == 0:
                        raise ValueError("Discovery returned no groups")

                    if all(len(v) == 0 for v in groups.values()):
                        raise ValueError("All groups are empty")

                    self.groups = groups
                    self.labels = labels
                    self.T = T

                    # Log success
                    logger.info(f"Discovery successful: {len(groups)} groups with sizes {[len(v) for v in groups.values()]}")

                    # [KEEP ALL YOUR EXISTING LOGIC BELOW - difficulty profiles, clusterer fitting, metrics, etc.]
                    if not hasattr(self, 'difficulty_profiles') or not self.difficulty_profiles:
                        self.difficulty_profiles = {}
                        for cluster_id in self.groups.keys():
                            self.difficulty_profiles[cluster_id] = {
                                'overall_difficulty': 0.5,  # Start neutral
                                'asr': 0.5,
                                'margin': 1.0,
                                'grad_complexity': 1.0
                            }
                        logger.info(f"Initialized {len(self.difficulty_profiles)} difficulty profiles")

                    # Check if clusterer exists and is initialized properly
                    if not hasattr(self, 'clusterer'):
                        self.clusterer = MultiViewClusterer(self.cfg, mode=self.cfg.mv_mode)

                    # Fit if not already fitted
                    if getattr(self.clusterer, 'stats_mu', None) is None:
                        if self.cfg.cluster_feature_type != "adaptive_comprehensive":
                            self.clusterer.fit(Xs, Xg)
                        else:
                            # Set minimal scalers for adaptive_comprehensive
                            self.clusterer.stats_mu = torch.zeros(10).to(self.device)
                            self.clusterer.stats_sigma = torch.ones(10).to(self.device)
                            self.clusterer.geom_mu = torch.zeros(10).to(self.device)
                            self.clusterer.geom_sigma = torch.ones(10).to(self.device)

                    # Clustering quality metrics
                    mvc_metrics = self.clusterer.quality_metrics(Xs, Xg)
                    self.detailed_metrics['silhouette_stats'].append(mvc_metrics.get('silhouette_stats', 0))
                    self.detailed_metrics['silhouette_geom'].append(mvc_metrics.get('silhouette_geom', 0))
                    self.detailed_metrics['multiview_agreement'].append(mvc_metrics.get('ARI_stats_geom', 0))
                    self.detailed_metrics['n_clusters_discovered'].append(len(groups))
                    self.detailed_metrics['cluster_sizes'].append([len(v) for v in groups.values()])
                    self.detailed_metrics['transition_matrix'].append(T.cpu().numpy() if T is not None else None)

                    logger.info(f"Discovery took {time.time() - discovery_start:.2f}s")
                    logger.info(f"Clustering quality - Silhouette(stats): {mvc_metrics.get('silhouette_stats', 0):.3f}, "
                              f"Silhouette(geom): {mvc_metrics.get('silhouette_geom', 0):.3f}, "
                              f"View agreement: {mvc_metrics.get('ARI_stats_geom', 0):.3f}")

                    self.update_cluster_difficulties()
                    logger.info(f"Updated cluster difficulties: {self.cluster_difficulties}")

                except Exception as e:
                    logger.error(f"Discovery failed: {e}")
                    import traceback
                    traceback.print_exc()

                    # Enhanced fallback: create emergency groups if none exist
                    if not hasattr(self, 'groups') or self.groups is None:
                        logger.warning("Creating emergency fallback groups to allow training to continue")
                        n_samples = len(self.train_set)
                        n_groups = self.cfg.n_clusters
                        self.groups = {i: list(range(i, n_samples, n_groups)) for i in range(n_groups)}
                        self.labels = torch.zeros(n_samples // self.cfg.batch_size + 1)
                        self.T = torch.eye(n_groups)
                        logger.info(f"Created {n_groups} fallback groups")
                    else:
                        logger.info("Using previous groups as fallback")

            # Verify groups exist and are non-empty
            if not hasattr(self, 'groups') or not self.groups:
                logger.error("No groups available; skipping epoch")
                continue

            # EXCLUDE CLUSTER HERE - before any order building
            if hasattr(self.cfg, 'excluded_cluster') and self.cfg.excluded_cluster is not None:
                if self.cfg.excluded_cluster in self.groups:
                    excluded_size = len(self.groups[self.cfg.excluded_cluster])
                    # Create filtered groups without the excluded cluster
                    groups = {k: v for k, v in self.groups.items() if k != self.cfg.excluded_cluster}
                    logger.info(f"Excluded cluster {self.cfg.excluded_cluster} ({excluded_size} samples) from training")
                else:
                    logger.warning(f"Cluster {self.cfg.excluded_cluster} not found in groups")
                    groups = self.groups
            else:
                groups = self.groups

            # NOW use the filtered 'groups' for everything
            n_batches = math.ceil(sum(len(v) for v in groups.values()) / self.cfg.batch_size)

            # Check if we discovered a best ordering during discovery
            if hasattr(self, 'discovered_best_order') and self.discovered_best_order:
                # Extend discovered order to full epoch length
                best_order = self.discovered_best_order
                while len(best_order) < n_batches:
                    best_order.extend(self.discovered_best_order)
                order = best_order[:n_batches]
                logger.info(f"Epoch {epoch}: Using discovered best ordering")
            elif self.cfg.order_mode == "ucb":
                if epoch < 3:
                    # Initial exploration
                    order = [np.random.randint(len(self.groups)) for _ in range(n_batches)]
                    logger.info(f"Epoch {epoch}: Exploration phase - random ordering")
                elif hasattr(self, 'T') and self.T is not None:
                    # Natural learning with rare fallbacks
                    order = self.scheduler.build_natural_order_with_fallbacks(self.groups, self.T, epoch)
                    logger.info(f"Epoch {epoch}: Natural progression with fallbacks")
                else:
                    # Fallback to UCB if no T matrix yet
                    order = self.scheduler.build_order(self.groups, self.T, epoch)
                    logger.info(f"Epoch {epoch}: UCB ordering (no T-matrix yet)")
            # elif self.cfg.order_mode == "uniform":
            #     order = self.get_uniform_order(self.groups)
            #     logger.info(f"Epoch {epoch}: Uniform round-robin ordering")
            else:
                # Default to uniform
                order = self.get_uniform_order(self.groups)

            # ADD THIS ENTIRE BLOCK - Filter excluded cluster from order
            if hasattr(self.cfg, 'excluded_cluster') and self.cfg.excluded_cluster is not None:
                valid_clusters = list(groups.keys())  # 'groups' already has exclusion applied
                # Replace any occurrence of excluded cluster in the order
                filtered_order = []
                for cluster_id in order:
                    if cluster_id == self.cfg.excluded_cluster:
                        # Replace with random valid cluster
                        filtered_order.append(np.random.choice(valid_clusters))
                    else:
                        filtered_order.append(cluster_id)
                order = filtered_order
                logger.info(f"Filtered order to remove cluster {self.cfg.excluded_cluster}")

            # Safety check for order
            if not order:
                logger.error(f"Empty order at epoch {epoch}")
                continue

            # Initialize block metrics
            self.block_losses = []

            # Create ordered dataloader
            # Detect Windows/WSL and adjust workers accordingly
            import platform

            is_windows = (os.name == 'nt') or (platform.system().lower() == 'windows')
            # is_wsl = False
            # try:
            #     # os.uname() only works on Linux/WSL
            #     is_wsl = 'microsoft' in os.uname().release.lower()
            # except AttributeError:
            #     pass
            is_wsl = False
            try:
                with open('/proc/version', 'r') as f:
                    is_wsl = 'microsoft' in f.read().lower()
            except:
                pass

            num_workers = 0 if (is_windows or is_wsl) else self.cfg.num_workers
            pin_memory = (num_workers > 0)

            filtered_groups = {k: v for k, v in self.groups.items() if len(v) > 0}
            if not filtered_groups:
                logger.error("All groups are empty, skipping epoch")
                continue

            if hasattr(self.cfg, 'excluded_cluster') and self.cfg.excluded_cluster is not None:
                if self.cfg.excluded_cluster in filtered_groups:
                    excluded_size = len(filtered_groups[self.cfg.excluded_cluster])
                    del filtered_groups[self.cfg.excluded_cluster]
                    logger.info(f"Excluded cluster {self.cfg.excluded_cluster} from training ({excluded_size} samples)")


            # Update order to only include valid groups
            valid_group_ids = set(filtered_groups.keys())
            filtered_order = [g for g in order if g in valid_group_ids]
            if not filtered_order:
                logger.error("No valid groups in order, using uniform fallback")
                filtered_order = list(valid_group_ids) * max(1, len(order) // len(valid_group_ids))

            train_loader_ordered = DataLoader(
                self.train_set,
                batch_sampler=GroupBatchSampler(
                    filtered_groups, filtered_order, self.cfg.batch_size,
                    self.train_set, drop_small=False
                ),
                num_workers=num_workers,
                pin_memory=pin_memory
            )

            # Store for budget throttling calculations
            self.batches_per_epoch = len(train_loader_ordered)

            expected_batches = len(order)
            actual_batches = len(train_loader_ordered)
            if len(order) != actual_batches:
                logger.warning(f"Order-sampler mismatch: {len(order)} vs {actual_batches}")
                if len(order) > actual_batches:
                    order = order[:actual_batches]
                else:
                    # Extend order by cycling through available clusters
                    available_clusters = list(self.groups.keys())
                    while len(order) < actual_batches:
                        order.append(available_clusters[len(order) % len(available_clusters)])
            if actual_batches != expected_batches:
                logger.warning(f"Mismatch: order has {expected_batches} but sampler has {actual_batches}")


            # Training
            self.model.train()
            total_batches = len(train_loader_ordered)
            prev_cluster = None  # Initialize for transition tracking

            cluster_sequence = [] 

            for step, (x, y) in enumerate(train_loader_ordered):
                if step >= len(order):
                    logger.warning(f"Step {step} exceeds order length {len(order)}, stopping epoch")
                    break

                try:
                    x = x.to(self.device)
                    y = y.to(self.device)

                    # FIX 1: Get cluster ID correctly
                    if step < len(order):
                        cluster_id = order[step]
                    else:
                        cluster_id = None
                        logger.warning(f"Step {step} beyond order length, no cluster assigned")

                    cluster_sequence.append(cluster_id)

                    # Ensure group buckets exist
                    self._ensure_group_buckets(cluster_id)

                    # Train with adaptive PGD (use your enhanced train_step_with_metrics)
                    if hasattr(self, 'train_step_with_metrics'):
                        metrics = self.train_step_with_metrics(x, y, cluster_id, step)
                    else:
                        result = self.train_step(x, y, cluster_id)
                        if isinstance(result, dict):
                            metrics = result
                        else:
                            # Handle old tuple format
                            metrics = {
                                'loss': result[0] if len(result) > 0 else 0,
                                'acc_clean': result[1] if len(result) > 1 else 0,
                                'acc_robust': result[2] if len(result) > 2 else 0,
                                'steps_used': result[3] if len(result) > 3 else self.cfg.pgd_steps
                            }

                    actual_steps_used = metrics.get('steps_used', self.cfg.pgd_steps)

                    # Log to verify early stopping is working
                    if step % 10 == 0:  # Every 10 steps
                        logger.debug(f"Step {step}: Used {actual_steps_used}/{self.cfg.pgd_steps} PGD steps")

                    # The PGD calls should already be tracked inside train_step
                    # Just verify the cluster attribution
                    if cluster_id is not None:
                        # Already tracked in train_step, just log for verification
                        current_calls = self.pgd_calls_per_cluster[cluster_id]
                        logger.debug(f"Cluster {cluster_id} total PGD calls: {current_calls}")


                    # FIX 7: Record transitions between different clusters only
                    if prev_cluster is not None and cluster_id is not None:
                        # Record the transition (including self-loops for natural learning)
                        self.transition_scheduler.update_transition(
                            prev_cluster, cluster_id,
                            epoch=self.current_epoch
                        )

                        # Track transition reward
                        if not hasattr(self, 'transition_rewards'):
                            self.transition_rewards = defaultdict(list)

                        current_robust = metrics.get('acc_robust', 0)
                        if hasattr(self, 'prev_robust_for_transition'):
                            delta = current_robust - self.prev_robust_for_transition
                            self.transition_rewards[(prev_cluster, cluster_id)].append(delta)
                            # Also update transition scheduler's rewards
                            if not hasattr(self.transition_scheduler, 'transition_rewards'):
                                self.transition_scheduler.transition_rewards = defaultdict(list)
                            self.transition_scheduler.transition_rewards[(prev_cluster, cluster_id)].append(delta)

                        # Update for next transition
                        self.prev_robust_for_transition = current_robust

                        # ADD THIS NEW BLOCK FOR TRANSITION QUALITY:
                        # Track detailed transition quality
                        if not hasattr(self, 'transition_qualities'):
                            self.transition_qualities = defaultdict(list)

                        current_loss = metrics.get('loss', float('inf'))
                        if hasattr(self, 'prev_loss_for_transition'):
                            transition_quality = {
                                'robust_gain': current_robust - self.prev_robust_for_transition,
                                'loss_decrease': self.prev_loss_for_transition - current_loss,
                                'steps_used': metrics.get('steps_used', self.cfg.pgd_steps),
                                'steps_saved': self.cfg.pgd_steps - metrics.get('steps_used', self.cfg.pgd_steps),
                                'epoch': epoch,
                                'batch': step,
                                'success': current_robust > self.prev_robust_for_transition
                            }
                            self.transition_qualities[(prev_cluster, cluster_id)].append(transition_quality)

                            # Double-count successful transitions
                            if transition_quality['success']:
                                self.transition_scheduler.update_transition(prev_cluster, cluster_id, epoch)

                        self.prev_loss_for_transition = current_loss

                        # Only update prev_cluster when actually transitioning
                        prev_cluster = cluster_id
                    else:
                        # First batch - initialize
                        prev_cluster = cluster_id
                        self.prev_robust_for_transition = metrics.get('acc_robust', 0)


                    # Accumulate epoch metrics
                    for k, v in metrics.items():
                        epoch_metrics[k].append(v)

                    # Group-specific metrics
                    self.detailed_metrics['group_loss'][cluster_id].append(metrics['loss'])
                    self.detailed_metrics['group_robust_acc'][cluster_id].append(metrics['acc_robust'])

                    # Log savings specifically for PGD calls
                    if step % 50 == 0:
                        steps_used = metrics.get('steps_used', self.cfg.pgd_steps)
                        is_adaptive = metrics.get('is_adaptive', False)

                        # ANALYZE LAST 50 CLUSTERS
                        recent_clusters = cluster_sequence[-50:] if len(cluster_sequence) >= 50 else cluster_sequence
                        cluster_counts = {}
                        for c in recent_clusters:
                            if c is not None:
                                cluster_counts[c] = cluster_counts.get(c, 0) + 1
                        
                        if cluster_counts:
                            dominant = max(cluster_counts.items(), key=lambda x: x[1])
                            cluster_dist_str = dict(sorted(cluster_counts.items()))
                        else:
                            dominant = (0, 0)
                            cluster_dist_str = {}

                        student_info = " [Student]" if hasattr(self, 'is_student_mode') and self.is_student_mode else ""
                        adaptive_info = f" [Adaptive: {is_adaptive}]" if is_adaptive else " [Standard]"

                        logger.info(
                            f"[{epoch+1}/{self.cfg.epochs}][{step}/{total_batches}] "
                            f"Loss: {metrics['loss']:.4f} | "
                            f"Robust: {metrics['acc_robust']:.3f} | "
                            f"Steps: {steps_used}/{self.cfg.pgd_steps} | "
                            f"Current: {cluster_id} | "
                            f"Last 50: {cluster_dist_str}"
                            f"{student_info}{adaptive_info}"
                        )

                    # FIX 4: Update reward immediately for this cluster
                    if cluster_id is not None and hasattr(self, 'scheduler'):
                        # Calculate immediate reward based on improvement
                        current_robust = metrics.get('acc_robust', 0)

                        # Use per-cluster baseline
                        if not hasattr(self, 'cluster_baseline'):
                            # Use actual number of groups discovered, not just cfg.n_clusters
                            max_cluster_id = max(self.groups.keys()) if self.groups else self.cfg.n_clusters
                            self.cluster_baseline = {k: 0.25 for k in range(max_cluster_id + 1)}

                        # Also add a safety check when accessing
                        if cluster_id not in self.cluster_baseline:
                            self.cluster_baseline[cluster_id] = 0.25
                        baseline = self.cluster_baseline[cluster_id]
                        improvement = current_robust - baseline

                        # Add efficiency component
                        steps_saved_ratio = 1.0 - (metrics.get('steps_used', self.cfg.pgd_steps) / self.cfg.pgd_steps)

                        # Combined reward
                        reward = 0.7 * improvement + 0.3 * steps_saved_ratio

                        # Update scheduler immediately (not waiting for blocks)
                        self.scheduler.rewards[cluster_id].append(reward)
                        self.scheduler.counts[cluster_id] += 1
                        self.scheduler.total_pulls += 1

                        # Update baseline with EMA
                        self.cluster_baseline[cluster_id] = 0.9 * baseline + 0.1 * current_robust

                    # Block rewards (existing code)
                    if hasattr(self, '_accumulate_block'):
                        self._accumulate_block(step, metrics['loss'])

                    if (step + 1) % self.cfg.block_size == 0 and cluster_id is not None:
                        # Enhanced reward incorporating both robustness gain and efficiency
                        recent_metrics = epoch_metrics.get('acc_robust', [])[-self.cfg.block_size:]
                        recent_steps = epoch_metrics.get('steps_used', [])[-self.cfg.block_size:]

                        if recent_metrics and recent_steps:
                            block_robust_mean = np.mean(recent_metrics)
                            block_efficiency = block_robust_mean / max(1e-6, np.mean(recent_steps) / 10)

                            # Baseline from previous block or cluster history
                            baseline_robust = 0.25  # Could use cluster history here

                            reward_dict = {
                                'delta_robust': block_robust_mean - baseline_robust,
                                'efficiency_score': block_efficiency,
                                'combined_reward': 0.7 * (block_robust_mean - baseline_robust) + 0.3 * block_efficiency
                            }

                            if hasattr(self, "scheduler"):
                                self.scheduler.update_reward(cluster_id, reward_dict['combined_reward'])
                                self.detailed_metrics['group_rewards'][cluster_id].append(reward_dict)

                except RuntimeError as e:
                    if "out of memory" in str(e) or "CUDA" in str(e):
                        logger.error(f"GPU error at step {step}: {e}")
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                        continue
                    else:
                        raise

            # Epoch summary
            epoch_time = time.time() - self.epoch_start_time

            # Aggregate epoch metrics (with safety checks)
            if epoch_metrics:
                self.detailed_metrics['epoch'].append(epoch + 1)
                self.detailed_metrics['train_loss'].append(np.mean(epoch_metrics['loss']) if epoch_metrics.get('loss') else 0)
                self.detailed_metrics['train_clean_acc'].append(np.mean(epoch_metrics['acc_clean']) if epoch_metrics.get('acc_clean') else 0)
                self.detailed_metrics['train_robust_acc'].append(np.mean(epoch_metrics['acc_robust']) if epoch_metrics.get('acc_robust') else 0)
                self.detailed_metrics['pgd_calls_cumulative'].append(self.pgd_calls_train)
                self.detailed_metrics['pgd_steps_saved'].append(sum(epoch_metrics.get('steps_saved', [0])))
                self.detailed_metrics['time_per_epoch'].append(epoch_time)
                self.detailed_metrics['cumulative_time'].append(time.time() - self.training_start_time)
                self.detailed_metrics['learning_rate'].append(self.optimizer.param_groups[0]['lr'])
                self.detailed_metrics['beta_value'].append(self._get_trades_beta())

                if hasattr(self, 'is_student_mode') and self.is_student_mode and epoch > 0:
                    self.update_student_performance()

                # Log epoch summary
                logger.info(f"\n{'='*40}")
                logger.info(f"Epoch {epoch+1} Summary:")
                logger.info(f"  Train Loss: {np.mean(epoch_metrics.get('loss', [0])):.4f}")
                logger.info(f"  Train Clean Acc: {np.mean(epoch_metrics.get('acc_clean', [0])):.3f}")
                logger.info(f"  Train Robust Acc: {np.mean(epoch_metrics.get('acc_robust', [0])):.3f}")
                logger.info(f"  PGD Steps Saved: {sum(epoch_metrics.get('steps_saved', [0])):,}")
                logger.info(f"  Epoch Time: {epoch_time:.1f}s")
                logger.info(f"  Total PGD Calls: {self.pgd_calls_train:,}")

                if self.robust_acc_per_cluster:
                    for cid in range(self.cfg.n_clusters):
                        if cid in self.robust_acc_per_cluster:
                            recent_robust = self.robust_acc_per_cluster[cid][-10:]
                            if recent_robust:
                                avg_robust = np.mean(recent_robust)
                                self.cluster_difficulties[cid] = 1.0 - avg_robust
                    logger.info(f"Updated difficulties: {self.cluster_difficulties}")

                cluster_visits = defaultdict(int)
                for cid in order[:min(step, len(order))]:
                    cluster_visits[cid] += 1
                logger.info(f"Cluster visits this epoch: {dict(cluster_visits)}")

            # Validation
            if (epoch + 1) % self.cfg.eval_interval == 0:
                val_metrics = self.evaluate(val_loader or self.val_loader)
                self.detailed_metrics['val_clean_acc'].append(val_metrics['clean'])
                self.detailed_metrics['val_robust_acc'].append(val_metrics['robust'])

                # More robust efficiency calculation
                efficiency = val_metrics['robust'] / max(1e-6, self.pgd_calls_train / 1e5)
                self.detailed_metrics['efficiency_score'].append(efficiency)

                logger.info(f"\nValidation Results:")
                logger.info(f"  Clean Accuracy: {val_metrics['clean']:.4f}")
                logger.info(f"  Robust Accuracy: {val_metrics['robust']:.4f}")
                logger.info(f"  Efficiency Score: {efficiency:.4f}")

                # Test periodically
                if test_loader is not None and (epoch + 1) % 5 == 0:
                    test_metrics = self.evaluate(test_loader)
                    self.detailed_metrics['test_clean_acc'].append(test_metrics['clean'])
                    self.detailed_metrics['test_robust_acc'].append(test_metrics['robust'])

                    logger.info(f"\nTest Results:")
                    logger.info(f"  Clean Accuracy: {test_metrics['clean']:.4f}")
                    logger.info(f"  Robust Accuracy: {test_metrics['robust']:.4f}")

                    if self.cfg.run_autoattack and (epoch + 1) % self.cfg.autoattack_freq == 0:
                        logger.info("Running AutoAttack evaluation...")
                        aa_metrics = self.evaluate_autoattack(test_loader)
                        self.detailed_metrics['aa_robust_acc'].append(aa_metrics['aa_robust'])
                        logger.info(f"  AutoAttack Robust: {aa_metrics['aa_robust']:.4f}")

            if self.pgd_calls_per_cluster:
                logger.info("\nPer-Cluster Efficiency:")
                for k in sorted(self.pgd_calls_per_cluster.keys()):
                    calls = self.pgd_calls_per_cluster[k]
                    robust = np.mean(self.robust_acc_per_cluster[k]) if self.robust_acc_per_cluster[k] else 0
                    efficiency = robust / max(1e-6, calls / 1e5)
                    logger.info(f"  Cluster {k}: Robust={robust:.3f}, PGD calls={calls:,}, Efficiency={efficiency:.3f}")


            # Save detailed metrics
            if (epoch + 1) % self.cfg.save_interval == 0:
                self._save_checkpoint(epoch + 1)
                self.save_detailed_metrics(epoch + 1)

            # Update learning rate (ensure scheduler exists)
            if epoch >= self.cfg.warmup_epochs and hasattr(self, 'lr_scheduler'):
                self.lr_scheduler.step()

            # Clear GPU cache at end of epoch for memory hygiene
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        logger.info(f"Saving final checkpoint at epoch {self.cfg.epochs}")
        self._save_checkpoint(self.cfg.epochs)

        if hasattr(self, 'transition_scheduler'):
            T_final = self.finalize_transition_matrix()
            logger.info("="*50)
            logger.info("Final Transition Matrix Analysis:")
            logger.info(f"Shape: {T_final.shape}")
            logger.info(f"Total transitions recorded: {int(self.transition_scheduler.counts.sum())}")

            # Save final T matrix
            t_matrix_path = self.cfg.experiment_dir / 'final_transition_matrix.npy'
            np.save(t_matrix_path, T_final)
            logger.info(f"Saved final transition matrix to {t_matrix_path}")
            logger.info("="*50)

        return self.detailed_metrics


    def save_detailed_metrics(self, epoch):
        """Save comprehensive metrics for paper analysis.

        Files written into ``self.cfg.experiment_dir``:
          * ``detailed_metrics_epoch_{epoch}.pkl``: pickle of ``self.detailed_metrics``.
          * ``time_series_epoch_{epoch}.csv``: ``self.time_series_log`` (only if non-empty).
          * ``training_metrics_epoch_{epoch}.csv``: every non-empty list-valued entry
            of ``self.detailed_metrics``, right-padded with NaN to a common length.
          * ``group_metrics_epoch_{epoch}.csv``: ``detailed_metrics['group_robust_acc']``,
            one NaN-padded column per group.
          * ``pgd_analysis_epoch_{epoch}.json``: PGD call accounting, per-cluster robust
            accuracy (mean of the last 100 entries) and cluster difficulties.

        Args:
            epoch: epoch number used in the file names and as the number of epochs
                in the full-budget reference count.
        """
        out_dir = self.cfg.experiment_dir
        nan = float('nan')

        def _pad(values, length):
            # Right-pad a ragged list with NaN so pandas can build a rectangular frame.
            return values + [nan] * (length - len(values))

        def _full_pgd_budget(n_epochs):
            # Reference count: cfg.pgd_steps per batch over n_epochs epochs.
            return self.cfg.pgd_steps * n_epochs * len(self.train_set) // self.cfg.batch_size

        with open(out_dir / f'detailed_metrics_epoch_{epoch}.pkl', 'wb') as fh:
            pickle.dump(self.detailed_metrics, fh)

        if self.time_series_log:
            ts_path = out_dir / f'time_series_epoch_{epoch}.csv'
            pd.DataFrame(self.time_series_log).to_csv(ts_path, index=False)
            logger.info(f"Saved time-series to {ts_path}")

        # Flat CSV of all list-valued metrics (ragged lists are NaN-padded).
        list_metrics = {
            name: values
            for name, values in self.detailed_metrics.items()
            if isinstance(values, list) and values
        }
        if list_metrics:
            width = max(map(len, list_metrics.values()))
            frame = pd.DataFrame({name: _pad(values, width) for name, values in list_metrics.items()})
            frame.to_csv(out_dir / f'training_metrics_epoch_{epoch}.csv', index=False)

        # Per-group robust accuracy CSV (also NaN-padded).
        groups = self.detailed_metrics.get('group_robust_acc', {})
        if groups:
            width = max(len(values) for values in groups.values())
            group_frame = pd.DataFrame({
                f'group_{gid}_robust_acc': _pad(values, width)
                for gid, values in groups.items()
            })
            if not group_frame.empty:
                group_frame.to_csv(out_dir / f'group_metrics_epoch_{epoch}.csv', index=False)

        # PGD efficiency analysis. The saved percentage divides by at least one
        # epoch's budget (and at least 1) so it is always defined.
        pgd_analysis = {
            'epoch': epoch,
            'total_pgd_calls': self.pgd_calls_train,
            'per_cluster_calls': dict(self.pgd_calls_per_cluster),
            'per_cluster_robust': {
                cid: np.mean(history[-100:]) if history else 0
                for cid, history in self.robust_acc_per_cluster.items()
            },
            'cluster_difficulties': dict(self.cluster_difficulties),
            'expected_calls': _full_pgd_budget(epoch),
            'saved_percentage': 100 * (1 - self.pgd_calls_train / max(1, _full_pgd_budget(max(1, epoch)))),
        }

        with open(out_dir / f'pgd_analysis_epoch_{epoch}.json', 'w') as fh:
            json.dump(pgd_analysis, fh, indent=2, default=str)

        logger.info(f"Saved detailed metrics to {self.cfg.experiment_dir}")


    def evaluate(self, dataloader=None) -> Dict[str, float]:
        """Evaluate clean and multi-restart PGD robust accuracy.

        The EMA model is evaluated when ``self.model_ema`` is not None, otherwise
        ``self.model``. If ``cfg.calibrate_bn`` is set and ``self.train_loader``
        exists, ``self.calibrate_bn`` is run on that model first. A sample counts as
        robust only if it stays correctly classified after every one of
        ``cfg.eval_pgd_restarts`` random-start PGD attacks. Each restart adds
        ``cfg.eval_pgd_steps * batch_size`` to ``self.pgd_calls_eval``.

        The evaluated model's train/eval mode is restored on exit. When there is no
        EMA copy, the evaluated model *is* the training model, and leaving it in
        eval mode would silently change subsequent training.

        Args:
            dataloader: iterable of ``(x, y)`` batches; inputs are normalized here with
                ``Attacks.normalize``. Defaults to ``self.val_loader``.

        Returns:
            ``{'clean': acc, 'robust': acc}`` as fractions of the samples seen.

        Raises:
            ValueError: if no dataloader is given and ``self.val_loader`` is not set.
        """
        if dataloader is None:
            if not hasattr(self, 'val_loader'):
                raise ValueError("No dataloader provided and no val_loader set")
            dataloader = self.val_loader

        model_eval = self.model_ema if self.model_ema is not None else self.model
        was_training = model_eval.training

        clean_correct = 0
        robust_correct = 0
        total = 0

        try:
            # Calibrate BN if needed
            if self.cfg.calibrate_bn and hasattr(self, 'train_loader'):
                self.calibrate_bn(model_eval, self.train_loader, self.cfg.calibration_steps)

            model_eval.eval()
            logger.info(f"Eval config: eps={self.cfg.epsilon:.4f}, steps={self.cfg.eval_pgd_steps}, restarts={self.cfg.eval_pgd_restarts}")

            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)

                # Clean accuracy (standard BN branch).
                with torch.no_grad(), UseAdvBN(model_eval, False):
                    logits_clean = model_eval(Attacks.normalize(x, self.mean, self.std))
                clean_correct += (logits_clean.argmax(1) == y).sum().item()

                # Robust accuracy: a sample must survive every restart.
                batch_robust = torch.ones(x.size(0), dtype=torch.bool, device=self.device)

                for _restart in range(self.cfg.eval_pgd_restarts):
                    # PGD needs input gradients, so it is deliberately kept outside
                    # torch.no_grad(); only the scoring forward pass below is no-grad.
                    x_adv, _ = Attacks.pgd(
                        model_eval, x, y,
                        self.cfg.epsilon,
                        self.cfg.pgd_step_size,
                        self.cfg.eval_pgd_steps,
                        self.mean, self.std,
                        random_start=True,
                        use_adv_bn=True,
                        early_stop=False
                    )

                    with torch.no_grad(), UseAdvBN(model_eval, True):
                        logits_adv = model_eval(Attacks.normalize(x_adv, self.mean, self.std))

                    batch_robust &= (logits_adv.argmax(1) == y)
                    self.pgd_calls_eval += self.cfg.eval_pgd_steps * x.size(0)

                robust_correct += batch_robust.sum().item()
                total += x.size(0)
        finally:
            model_eval.train(was_training)

        return {
            'clean': clean_correct / max(1, total),
            'robust': robust_correct / max(1, total)
        }

    def evaluate_autoattack(self, dataloader=None, epsilon=None):
        """Run a reduced AutoAttack (APGD-CE + APGD-T, L-inf) on up to 1000 samples.

        The attacked network is the EMA model when ``self.model_ema`` is not None,
        otherwise ``self.model``. AutoAttack perturbs and feeds raw pixel inputs,
        whereas every other call site in this trainer passes
        ``Attacks.normalize(x, mean, std)`` to the model. The model is therefore
        wrapped so that it normalizes its own input. The adversarial BN branch is
        selected with ``UseAdvBN(..., True)``, matching ``evaluate``. The model's
        train/eval mode is restored afterwards.

        Args:
            dataloader: ``(x, y)`` batches of pixel-space inputs. Defaults to
                ``self.test_loader`` when set, otherwise ``self.val_loader``.
            epsilon: L-inf budget; defaults to ``cfg.epsilon``.

        Returns:
            ``{'aa_robust': acc}``. ``acc`` is -1 if the ``autoattack`` package is
            not installed or the dataloader yielded no samples.
        """
        if dataloader is None:
            # Explicit fallback: getattr(..., self.val_loader) would evaluate
            # val_loader eagerly and raise if only test_loader exists.
            dataloader = getattr(self, 'test_loader', None)
            if dataloader is None:
                dataloader = self.val_loader

        if epsilon is None:
            epsilon = self.cfg.epsilon

        try:
            from autoattack import AutoAttack
        except ImportError:
            logger.warning("AutoAttack not installed. Skipping AA evaluation.")
            return {'aa_robust': -1}

        model_eval = self.model_ema if self.model_ema is not None else self.model
        mean, std = self.mean, self.std

        class _PixelSpaceModel(nn.Module):
            # Accepts pixel-space inputs and normalizes them before the wrapped net.
            def __init__(self, net):
                super().__init__()
                self.net = net

            def forward(self, x):
                return self.net(Attacks.normalize(x, mean, std))

        # Collect a subset for AA (first 1000 samples in loader order).
        x_chunks, y_chunks = [], []
        n_collected = 0
        for x, y in dataloader:
            x_chunks.append(x)
            y_chunks.append(y)
            n_collected += x.size(0)
            if n_collected >= 1000:
                break

        if not x_chunks:
            logger.warning("AutoAttack skipped: dataloader yielded no samples.")
            return {'aa_robust': -1}

        x_test = torch.cat(x_chunks, 0)[:1000].to(self.device)
        y_test = torch.cat(y_chunks, 0)[:1000].to(self.device)

        pixel_model = _PixelSpaceModel(model_eval)
        was_training = model_eval.training
        pixel_model.eval()  # propagates eval mode to model_eval
        try:
            with UseAdvBN(model_eval, True):
                adversary = AutoAttack(pixel_model, norm='Linf', eps=epsilon,
                                       version='standard', device=self.device)
                adversary.attacks_to_run = ['apgd-ce', 'apgd-t']  # Faster subset

                x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=128)

                with torch.no_grad():
                    outputs = pixel_model(x_adv)
                    acc = (outputs.argmax(1) == y_test).float().mean().item()
        finally:
            model_eval.train(was_training)

        return {'aa_robust': acc}


    def _evaluate_feature_for_learning(self, feature_name, features, sample_size=500):
        """Score how well one feature view groups samples into learnable clusters.

        Samples up to ``sample_size`` rows of ``features`` without replacement, using
        NumPy's global RNG. Those rows are clustered with ``MiniBatchKMeans``
        (``random_state=42``) for k in 2..5. Each clustering is scored as
        ``diversity + 0.3 * consistency + 0.5 * learning_potential``, using
        ``_compute_cluster_diversity``, ``_compute_within_cluster_consistency`` and
        ``_test_learning_potential``.

        Args:
            feature_name: name of the feature view; not used in the computation.
            features: per-sample matrix whose row i corresponds to ``self.train_set[i]``,
                because the sampled row indices are passed to probes that index
                ``train_set``. Presumably a torch tensor, since ``.cpu().numpy()`` is
                called; confirm with callers.
            sample_size: maximum number of rows to sample.

        Returns:
            ``(best_score, best_k)``. Values of k larger than the number of sampled
            rows are skipped, so with fewer than two rows this is ``(-inf, 2)``.
        """
        from sklearn.cluster import MiniBatchKMeans

        # Sample for efficiency
        n_samples = min(sample_size, len(features))
        indices = np.random.choice(len(features), n_samples, replace=False)
        # Convert once; the same matrix is clustered for every k.
        sampled_np = features[indices].cpu().numpy()

        best_score = -float('inf')
        best_k = 2

        for k in [2, 3, 4, 5]:
            if k > n_samples:
                # KMeans rejects n_clusters > n_samples; larger k cannot fit either.
                break

            kmeans = MiniBatchKMeans(n_clusters=k, random_state=42, max_iter=30)
            labels = kmeans.fit_predict(sampled_np)

            # Test if these clusters represent different learning patterns
            diversity_score = self._compute_cluster_diversity(indices, labels, k)
            consistency_score = self._compute_within_cluster_consistency(indices, labels, k)
            learning_potential = self._test_learning_potential(indices, labels, k)

            combined_score = diversity_score + 0.3 * consistency_score + 0.5 * learning_potential

            if combined_score > best_score:
                best_score = combined_score
                best_k = k

        return best_score, best_k

    def _compute_cluster_diversity(self, indices, labels, k):
        """Measure how differently clusters respond to a weak FGSM perturbation.

        For the first 100 ``(index, label)`` pairs, the training sample is attacked
        with ``Attacks.fgsm`` at ``0.5 * cfg.epsilon``. The probe records whether it
        stays correctly classified, along with the drop in max softmax confidence.
        The score is the standard deviation, across clusters, of the per-cluster
        robust rate.

        The model is probed in eval mode and its mode is restored afterwards. In
        train mode these single-sample forward passes would overwrite the
        BatchNorm running statistics.

        Args:
            indices: ``self.train_set`` indices of the sampled points.
            labels: cluster label of each entry of ``indices``.
            k: number of clusters (unused here).

        Returns:
            Std-dev of per-cluster robust rates; 0.0 if fewer than two clusters
            were probed.
        """
        cluster_responses = defaultdict(list)

        was_training = self.model.training
        self.model.eval()
        try:
            for idx, label in zip(indices[:100], labels[:100]):  # Sample for speed
                x, y = self.train_set[idx]
                x = x.unsqueeze(0).to(self.device)
                y = torch.tensor([y]).to(self.device)

                # FGSM needs an input gradient, so it runs outside no_grad; this works
                # whether or not Attacks.fgsm re-enables gradients internally.
                x_adv = Attacks.fgsm(self.model, x, y, self.cfg.epsilon * 0.5,
                                     self.mean, self.std, use_adv_bn=True)

                with torch.no_grad():
                    logits_clean = self.model(Attacks.normalize(x, self.mean, self.std))
                    logits_adv = self.model(Attacks.normalize(x_adv, self.mean, self.std))
                    cluster_responses[label].append({
                        'robust': (logits_adv.argmax(1) == y).item(),
                        'confidence_drop': (F.softmax(logits_clean, dim=1).max() -
                                            F.softmax(logits_adv, dim=1).max()).item()
                    })
        finally:
            self.model.train(was_training)

        # Per-cluster means of the recorded responses.
        cluster_means = {
            label: {
                'robust': np.mean([r['robust'] for r in responses]),
                'conf_drop': np.mean([r['confidence_drop'] for r in responses])
            }
            for label, responses in cluster_responses.items()
            if responses
        }

        if len(cluster_means) < 2:
            return 0.0

        # Spread of the per-cluster robust rates.
        robust_values = [m['robust'] for m in cluster_means.values()]
        return np.std(robust_values)

    def _compute_within_cluster_consistency(self, indices, labels, k):
        """Measure how consistent input-gradient magnitudes are within each cluster.

        For every cluster id in ``0..k-1`` with at least two members, the L2 norm of
        the input gradient of the cross-entropy loss is computed for up to 20
        members. The cluster's score is ``1 / (1 + CV)``, where CV is the
        coefficient of variation of those norms: 1.0 means identical norms.

        The model is probed in eval mode, restored afterwards, so these forward
        passes do not update BatchNorm running statistics. Gradients are enabled
        explicitly so the probe also works when called under ``torch.no_grad()``.

        Args:
            indices: ``self.train_set`` indices of the sampled points.
            labels: cluster label of each entry of ``indices``.
            k: number of clusters.

        Returns:
            Mean per-cluster score, or 0.0 if no cluster had two or more members.
        """
        cluster_scores = []

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.enable_grad():
                for label in range(k):
                    members = [idx for idx, lab in zip(indices, labels) if lab == label]
                    if len(members) < 2:
                        continue

                    grad_norms = []
                    for idx in members[:20]:  # Sample
                        x, y = self.train_set[idx]
                        x = x.unsqueeze(0).to(self.device).requires_grad_(True)
                        y = torch.tensor([y]).to(self.device)

                        loss = F.cross_entropy(self.model(Attacks.normalize(x, self.mean, self.std)), y)
                        grad = torch.autograd.grad(loss, x)[0]
                        grad_norms.append(grad.norm().item())

                    # Lower relative spread -> higher consistency, mapped into (0, 1].
                    variation = np.std(grad_norms) / (np.mean(grad_norms) + 1e-8)
                    cluster_scores.append(1.0 / (1.0 + variation))
        finally:
            self.model.train(was_training)

        return np.mean(cluster_scores) if cluster_scores else 0.0

    def _test_learning_potential(self, indices, labels, k):
        """Heuristic: mean loss decrease when stepping through clusters 0..k-1 in order.

        For each non-empty cluster (in id order), the clean cross-entropy loss of up
        to 16 of its members is computed. The score is the mean of
        ``prev_loss - loss`` over consecutive non-empty clusters. No training
        happens here, so this telescopes to ``(first_loss - last_loss) / (m - 1)``
        for ``m`` non-empty clusters.

        The model is probed in eval mode, restored afterwards, so these forward
        passes do not update BatchNorm running statistics.

        Args:
            indices: ``self.train_set`` indices of the sampled points.
            labels: cluster label of each entry of ``indices``.
            k: number of clusters.

        Returns:
            Mean consecutive loss decrease, or 0.0 with fewer than two non-empty clusters.
        """
        improvements = []
        prev_loss = None

        was_training = self.model.training
        self.model.eval()
        try:
            for cluster_id in range(k):
                members = [idx for idx, lab in zip(indices, labels) if lab == cluster_id]
                if not members:
                    continue

                # Fetch each sample once (train_set transforms may be random).
                samples = [self.train_set[idx] for idx in members[:16]]
                x = torch.stack([img for img, _ in samples]).to(self.device)
                y = torch.tensor([target for _, target in samples]).to(self.device)

                with torch.no_grad():
                    loss = F.cross_entropy(self.model(Attacks.normalize(x, self.mean, self.std)), y).item()

                if prev_loss is not None:
                    improvements.append(prev_loss - loss)
                prev_loss = loss
        finally:
            self.model.train(was_training)

        return np.mean(improvements) if improvements else 0.0

    def _pgd_with_count(self, x, y, eval_mode: bool = False):
        """L-inf PGD (uniform random start, cross-entropy ascent) that reports steps used.

        Runs ``cfg.eval_pgd_steps`` steps when ``eval_mode``, otherwise
        ``cfg.pgd_steps``. The step size is ``cfg.pgd_step_size`` and the radius is
        ``cfg.epsilon``. The attack goes through the adversarial BN branch with the
        model in eval mode. In training mode with ``cfg.early_stop_pgd``, it stops
        once at least ``cfg.min_early_stop_steps`` steps were taken and every
        sample in the batch is misclassified.

        The model's original train/eval mode is restored even if the attack raises,
        e.g. a CUDA out-of-memory RuntimeError that a caller catches and continues
        past.

        Returns:
            ``(x_adv, used_steps)``: the detached adversarial batch, clamped to
            ``[0, 1]`` and to within ``epsilon`` (L-inf) of ``x``, and the number of
            gradient steps taken.
        """
        steps = self.cfg.eval_pgd_steps if eval_mode else self.cfg.pgd_steps
        step_size = self.cfg.pgd_step_size
        eps = self.cfg.epsilon
        used_steps = 0

        was_train = self.model.training
        self.model.eval()
        try:
            x0 = x.detach()
            x_adv = (x0 + torch.empty_like(x0).uniform_(-eps, eps)).clamp(0, 1)
            # enable_grad so the attack also works when called under torch.no_grad().
            with torch.enable_grad(), UseAdvBN(self.model, True):
                for _ in range(steps):
                    x_adv = x_adv.detach().requires_grad_(True)
                    logits = self.model(Attacks.normalize(x_adv, self.mean, self.std))
                    loss = F.cross_entropy(logits, y)
                    grad = torch.autograd.grad(loss, x_adv)[0]
                    x_adv = (x_adv + step_size * grad.sign()).clamp(x0 - eps, x0 + eps).clamp(0, 1).detach()
                    used_steps += 1
                    if (not eval_mode) and self.cfg.early_stop_pgd and used_steps >= self.cfg.min_early_stop_steps:
                        with torch.no_grad():
                            pred = self.model(Attacks.normalize(x_adv, self.mean, self.std)).argmax(1)
                        if (pred != y).all():
                            break
        finally:
            self.model.train(was_train)

        return x_adv, used_steps

    def _trades_step(self, x_clean, x_adv, y):
        """Perform one TRADES-style optimizer step and return the scalar loss.

        loss = CE(f(x_clean), y) + beta * KL(softmax(f(x_clean)) || softmax(f(x_adv)))

        The clean distribution is a constant target (computed under no_grad). Clean
        and adversarial forwards use the standard and adversarial BN branches
        respectively. Gradients are norm-clipped when ``cfg.grad_clip_norm`` is
        positive, and the EMA model is updated when ``self.model_ema`` exists.
        """
        opt = self.optimizer
        opt.zero_grad(set_to_none=True)
        beta = self._current_trades_beta()

        with UseAdvBN(self.model, False):
            logits_nat = self.model(Attacks.normalize(x_clean, self.mean, self.std))
        with UseAdvBN(self.model, True):
            logits_adv = self.model(Attacks.normalize(x_adv, self.mean, self.std))

        natural_term = F.cross_entropy(logits_nat, y)
        with torch.no_grad():
            target_probs = F.softmax(logits_nat, dim=1)
        robust_term = F.kl_div(F.log_softmax(logits_adv, dim=1), target_probs, reduction='batchmean')

        total = natural_term + beta * robust_term
        total.backward()

        clip = self.cfg.grad_clip_norm
        if clip and clip > 0:
            nn.utils.clip_grad_norm_(self.model.parameters(), clip)

        opt.step()
        if self.model_ema is not None:
            self._ema_update()

        return float(total.item())



    def _save_checkpoint(self, epoch):
        """Save a training checkpoint to ``experiment_dir/checkpoint_epoch_{epoch}.pth``.

        The checkpoint stores model, EMA, optimizer and LR-scheduler state, together
        with ``self.metrics``, the training PGD call count and the config (as a dict).
        ``ema_state_dict`` is None when there is no EMA model, and
        ``scheduler_state_dict`` is None when the trainer has no ``lr_scheduler``.
        The training loop only steps the scheduler if that attribute exists, so
        saving must not assume it either.

        Args:
            epoch: epoch number stored in the checkpoint and used in the file name.
        """
        lr_scheduler = getattr(self, 'lr_scheduler', None)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'ema_state_dict': self.model_ema.state_dict() if self.model_ema is not None else None,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': lr_scheduler.state_dict() if lr_scheduler is not None else None,
            'metrics': dict(self.metrics),
            'pgd_calls_train': self.pgd_calls_train,
            'config': asdict(self.cfg)
        }

        path = self.cfg.experiment_dir / f'checkpoint_epoch_{epoch}.pth'
        torch.save(checkpoint, path)
        logger.info(f"Saved checkpoint to {path}")

    def _write_ablation_results(self, results):
        """Write ablation comparison rows to ``experiment_dir/ablation_results.csv``.

        Args:
            results: list of dicts, one per ablation run. Columns are the union of
                all rows' keys in first-seen order; a key missing from a row is
                written as an empty cell. Using only the first row's keys would make
                ``csv.DictWriter`` raise on any later row that has an extra key.
                Nothing is written when ``results`` is empty.
        """
        if not results:
            return

        fieldnames = list(dict.fromkeys(key for row in results for key in row))
        csv_path = self.cfg.experiment_dir / 'ablation_results.csv'

        with open(csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(results)

        logger.info(f"Wrote ablation results to {csv_path}")



# ========================= Statistical Analysis =========================
class ExperimentAnalyzer:
    """Analyze results across multiple runs for statistical significance"""

    @staticmethod
    def compute_auc_robust_vs_pgd(time_series_logs):
        """Compute the AUC of robust accuracy vs. cumulative PGD calls.

        Args:
            time_series_logs: sequence of dicts with numeric ``'pgd_calls_cumulative'``
                and ``'robust_acc'`` entries. The call counts must be monotonic,
                because sklearn's ``auc`` rejects a non-monotonic x axis; cumulative
                counts satisfy this.

        Returns:
            Trapezoidal area, with the call axis rescaled to [0, 1] by its maximum.
            Returns 0.0 for fewer than two points or when the maximum call count
            is <= 0.
        """
        calls = [entry['pgd_calls_cumulative'] for entry in time_series_logs]
        robust = [entry['robust_acc'] for entry in time_series_logs]

        if len(calls) < 2:
            return 0.0

        max_calls = max(calls)
        if max_calls <= 0:
            return 0.0

        # Imported lazily so the other helpers work without scikit-learn.
        from sklearn.metrics import auc

        calls_norm = [c / max_calls for c in calls]
        return auc(calls_norm, robust)

    @staticmethod
    def compute_time_to_target(time_series_logs, target_acc=0.40):
        """Return the first log entry whose robust accuracy reaches ``target_acc``.

        Returns:
            ``{'steps', 'pgd_calls', 'wall_time'}`` taken from that entry's
            ``'step'``, ``'pgd_calls_cumulative'`` and ``'wall_time'``, or None if
            the target is never reached.
        """
        for entry in time_series_logs:
            if entry['robust_acc'] >= target_acc:
                return {
                    'steps': entry['step'],
                    'pgd_calls': entry['pgd_calls_cumulative'],
                    'wall_time': entry['wall_time']
                }
        return None

    @staticmethod
    def compute_cluster_stability(labels_by_epoch):
        """Adjusted Rand Index between each pair of consecutive epochs' cluster labels.

        Returns:
            List of ``len(labels_by_epoch) - 1`` ARI values (empty for < 2 epochs).
        """
        from sklearn.metrics import adjusted_rand_score

        return [
            adjusted_rand_score(prev_labels, next_labels)
            for prev_labels, next_labels in zip(labels_by_epoch[:-1], labels_by_epoch[1:])
        ]

    @staticmethod
    def bootstrap_ci(data, n_bootstrap=1000, ci=0.95, rng=None):
        """Percentile bootstrap confidence interval for the mean of ``data``.

        Args:
            data: 1-D sequence of numbers.
            n_bootstrap: number of resamples (each the same size as ``data``).
            ci: confidence level, e.g. 0.95 for a 95% interval.
            rng: optional ``numpy.random.Generator`` (or ``RandomState``) for
                reproducible resampling; defaults to NumPy's global RNG.

        Returns:
            ``{'mean', 'std', 'ci_lower', 'ci_upper'}``. All values are NaN when
            ``data`` is empty.
        """
        values = np.asarray(data, dtype=float)
        if values.size == 0:
            nan = float('nan')
            return {'mean': nan, 'std': nan, 'ci_lower': nan, 'ci_upper': nan}

        sampler = np.random if rng is None else rng
        # Draw every resample in one call: row i is the i-th bootstrap sample.
        resamples = sampler.choice(values, size=(n_bootstrap, values.size), replace=True)
        bootstrapped = resamples.mean(axis=1)

        alpha = (1 - ci) / 2
        return {
            'mean': np.mean(values),
            'std': np.std(values),
            'ci_lower': np.percentile(bootstrapped, alpha * 100),
            'ci_upper': np.percentile(bootstrapped, (1 - alpha) * 100)
        }



class SimCLREncoder(nn.Module):
    """SimCLR-style encoder: a ResNet backbone (final fc removed) plus an MLP projector.

    Args:
        base_model: backbone name; only ``'resnet18'`` (randomly initialized) is supported.
        feature_dim: output dimension of the projection head.

    Raises:
        NotImplementedError: for any other ``base_model``.
    """
    def __init__(self, base_model='resnet18', feature_dim=128):
        super().__init__()

        # Base encoder: every ResNet child except the final fc layer.
        if base_model == 'resnet18':
            resnet = models.resnet18(weights=None)
            self.encoder = nn.Sequential(*list(resnet.children())[:-1])
            hidden_dim = 512
        else:
            raise NotImplementedError(f"Unsupported base_model: {base_model!r}")

        # Projection head
        self.projector = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim)
        )

    def forward(self, x):
        """Return ``(h, z)``: backbone features ``(B, hidden_dim)`` and projections ``(B, feature_dim)``.

        ``flatten(1)`` is used instead of ``squeeze()``: squeeze would also drop the
        batch dimension when B == 1 and return a 1-D vector.
        """
        h = torch.flatten(self.encoder(x), 1)
        z = self.projector(h)
        return h, z  # Return both representation and projection

    @torch.no_grad()
    def extract_features(self, x):
        """Extract backbone features ``(B, hidden_dim)`` without the projection head."""
        return torch.flatten(self.encoder(x), 1)


# ========================= Data Loading =========================
def get_dataloaders(cfg):
    """Create CIFAR-10 train / validation / test DataLoaders.

    The official training split is divided into train and validation subsets by
    ``random_split``, after seeding the global torch RNG with ``cfg.seed``. This
    makes the split reproducible; it also resets the global torch RNG.

    Only the training subset is augmented (4px-padded random crop plus horizontal
    flip). Validation and test images are plain ``ToTensor()`` outputs in [0, 1].
    ``random_split`` subsets share their parent dataset's transform, so the
    validation indices are taken from a second, non-augmented instance of the
    training split. Otherwise validation images would be randomly cropped and
    flipped.

    Args:
        cfg: must provide ``data_root``, ``val_split``, ``seed``, ``batch_size``
            and ``num_workers``.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    transform_test = transforms.ToTensor()

    trainset = torchvision.datasets.CIFAR10(
        root=cfg.data_root, train=True, download=True, transform=transform_train
    )
    # Same training images with the deterministic eval transform (for validation).
    trainset_eval = torchvision.datasets.CIFAR10(
        root=cfg.data_root, train=True, download=True, transform=transform_test
    )

    testset = torchvision.datasets.CIFAR10(
        root=cfg.data_root, train=False, download=True, transform=transform_test
    )

    # Split train/val
    n_train = int(len(trainset) * (1 - cfg.val_split))
    n_val = len(trainset) - n_train

    torch.manual_seed(cfg.seed)
    train_subset, val_split_subset = torch.utils.data.random_split(
        trainset, [n_train, n_val]
    )
    val_subset = Subset(trainset_eval, val_split_subset.indices)

    train_loader = DataLoader(
        train_subset, batch_size=cfg.batch_size,
        shuffle=True, num_workers=cfg.num_workers, pin_memory=True
    )

    val_loader = DataLoader(
        val_subset, batch_size=cfg.batch_size,
        shuffle=False, num_workers=cfg.num_workers, pin_memory=True
    )

    test_loader = DataLoader(
        testset, batch_size=cfg.batch_size,
        shuffle=False, num_workers=cfg.num_workers, pin_memory=True
    )

    return train_loader, val_loader, test_loader

class ClusterOrderingTester:
    """Tests different cluster orderings during teacher training to find optimal curriculum.

    ``trainer`` must provide ``model``, ``cfg.batch_size``, ``device``, ``train_set``
    and ``train_step(x, y, cluster_id)``. ``train_step`` must return a dict with
    ``'acc_robust'`` and ``'loss'``. Optional attributes are used when present:
    ``optimizer`` and ``model_ema`` (snapshotted around trial runs) and ``T``
    (a transition-matrix tensor used by the transition-based candidate).
    """

    def __init__(self, trainer):
        self.trainer = trainer
        self.ordering_performance = {}
        self.tested_orderings = []

    def test_ordering_online(self, current_epoch, groups, order, metrics):
        """Append a performance record for the ordering keyed by its first 10 cluster ids."""
        order_key = tuple(order[:10])  # Use first 10 as signature
        self.ordering_performance.setdefault(order_key, []).append({
            'epoch': current_epoch,
            'robust_acc': metrics.get('acc_robust', 0),
            'loss': metrics.get('loss', float('inf')),
            'efficiency': metrics.get('steps_saved', 0) / max(1, metrics.get('steps_used', 10))
        })

    def find_best_orderings(self, groups, n_test=5, test_epochs=2):
        """Trial-train each candidate ordering from the same starting state; return the best.

        Model weights, plus optimizer state and EMA weights when present, are
        snapshotted first. They are restored before every candidate, so each
        ordering starts from identical state and momentum cannot leak between
        trials. They are restored again afterwards, even if a trial raises. Any
        counters that ``trainer.train_step`` updates are not rolled back.

        Returns:
            ``((name, sequence), scores)``, where ``scores`` maps ordering name to
            score. Falls back to a cyclical ordering when no candidate was scored.
        """
        logger.info(f"Testing {n_test} different orderings for {test_epochs} epochs each...")

        candidate_orderings = self._generate_candidate_orderings(groups, n_test)
        ordering_scores = {}

        snapshot = self._snapshot_trainer_state()
        try:
            for i, (name, order) in enumerate(candidate_orderings):
                logger.info(f"Testing ordering {i+1}/{len(candidate_orderings)}: {name}")

                # Reset model/optimizer/EMA to the same starting point
                self._restore_trainer_state(snapshot)

                # Mini training loop with this ordering
                score = self._evaluate_ordering(groups, order, test_epochs)
                ordering_scores[name] = score

                logger.info(f"  Score: {score:.4f}")
        finally:
            # Restore original trainer state
            self._restore_trainer_state(snapshot)

        if not ordering_scores:
            return ("cyclical", list(range(len(groups))) * 100), ordering_scores

        # Highest score wins; ties go to the earliest candidate.
        best_name = max(ordering_scores, key=ordering_scores.get)
        return (best_name, dict(candidate_orderings)[best_name]), ordering_scores

    def _snapshot_trainer_state(self):
        """Deep-copy the trainer state that trial training mutates (model, optimizer, EMA)."""
        trainer = self.trainer
        optimizer = getattr(trainer, 'optimizer', None)
        model_ema = getattr(trainer, 'model_ema', None)
        return {
            'model': copy.deepcopy(trainer.model.state_dict()),
            'optimizer': copy.deepcopy(optimizer.state_dict()) if optimizer is not None else None,
            'model_ema': copy.deepcopy(model_ema.state_dict()) if model_ema is not None else None,
        }

    def _restore_trainer_state(self, snapshot):
        """Load a snapshot produced by ``_snapshot_trainer_state`` back into the trainer."""
        trainer = self.trainer
        trainer.model.load_state_dict(snapshot['model'])
        if snapshot['optimizer'] is not None:
            # Copy again: Optimizer.load_state_dict may keep references to the given
            # tensors, and later steps would then mutate the snapshot in place.
            trainer.optimizer.load_state_dict(copy.deepcopy(snapshot['optimizer']))
        if snapshot['model_ema'] is not None:
            trainer.model_ema.load_state_dict(snapshot['model_ema'])

    def _generate_candidate_orderings(self, groups, n_candidates):
        """Build up to ``n_candidates`` named cluster-id sequences (one id per batch).

        Cluster ids are assumed to be ``0 .. len(groups) - 1``, since the sequences
        use modular arithmetic over that range. The candidates, in order, are:
        foundation_first, cyclical, interleaved_pairs, progressive, and
        transition_based (only when ``trainer.T`` is set). ``progressive`` and
        ``transition_based`` draw from NumPy's global RNG.
        """
        orderings = []
        n_clusters = len(groups)
        n_batches = sum(len(g) for g in groups.values()) // self.trainer.cfg.batch_size

        # 1. Foundation-first: a quarter of the batches on the largest cluster, then cycle.
        foundation = self._find_foundation_cluster(groups)
        order = [foundation] * (n_batches // 4)
        order.extend((foundation + 1 + i) % n_clusters for i in range(n_batches - len(order)))
        orderings.append(("foundation_first", order))

        # 2. Cyclical (0->1->2->...->0)
        orderings.append(("cyclical", [i % n_clusters for i in range(n_batches)]))

        # 3. Interleaved pairs: two batches alternating over clusters 0/1, then two
        # batches cycling over clusters 2, 3, ...  With <= 2 clusters there is no
        # "rest" group, so only the clusters that exist are alternated.
        n_pair = min(2, n_clusters)
        n_rest = n_clusters - 2
        pairs = []
        for i in range(n_batches):
            if i % 4 < 2 or n_rest <= 0:
                pairs.append(i % n_pair)
            else:
                pairs.append(2 + (i % n_rest))
        orderings.append(("interleaved_pairs", pairs))

        # 4. Progressive (gradually introduce harder)
        progressive = []
        phase_length = n_batches // n_clusters
        for phase in range(n_clusters):
            for _ in range(phase_length):
                # Sample from clusters 0 to phase
                progressive.append(np.random.randint(0, phase + 1))
        progressive.extend([np.random.randint(n_clusters) for _ in range(n_batches - len(progressive))])
        orderings.append(("progressive", progressive))

        # 5. Discovered transition-based (if T matrix exists): random walk from cluster 0.
        transition_matrix = getattr(self.trainer, 'T', None)
        if transition_matrix is not None:
            T = transition_matrix.cpu().numpy()
            transition_order = []
            current = 0
            for _ in range(n_batches):
                row = np.asarray(T[current], dtype=np.float64)
                row_sum = row.sum()
                # Rows with no recorded transitions fall back to uniform; otherwise
                # np.random.choice rejects probabilities that do not sum to 1.
                probs = row / row_sum if row_sum > 0 else np.full(len(row), 1.0 / len(row))
                current = int(np.random.choice(len(probs), p=probs))
                transition_order.append(current)
            orderings.append(("transition_based", transition_order))

        return orderings[:n_candidates]

    def _evaluate_ordering(self, groups, order, epochs):
        """Run a short trial (<= 100 batches per epoch) with ``order`` and return its score.

        Per epoch: ``mean(acc_robust) - reference + 0.3 * (2.0 - mean(loss))``. The
        reference is 0.3 in the first epoch and the running mean score afterwards.
        The returned value is the mean over ``epochs``.
        """
        total_score = 0

        for epoch in range(epochs):
            epoch_robust = []
            epoch_loss = []

            # Create ordered loader
            train_loader_ordered = DataLoader(
                self.trainer.train_set,
                batch_sampler=GroupBatchSampler(
                    groups, order[:100],  # Test with first 100 batches
                    self.trainer.cfg.batch_size,
                    self.trainer.train_set, drop_small=False
                ),
                num_workers=0
            )

            for step, (x, y) in enumerate(train_loader_ordered):
                if step >= 100:  # Quick test
                    break

                x, y = x.to(self.trainer.device), y.to(self.trainer.device)
                cluster_id = order[step] if step < len(order) else 0

                # Quick training step
                metrics = self.trainer.train_step(x, y, cluster_id)
                epoch_robust.append(metrics['acc_robust'])
                epoch_loss.append(metrics['loss'])

            # Score this epoch
            robust_improvement = np.mean(epoch_robust) - (0.3 if epoch == 0 else total_score / max(1, epoch))
            loss_decrease = 2.0 - np.mean(epoch_loss)

            epoch_score = robust_improvement + 0.3 * loss_decrease
            total_score += epoch_score

        return total_score / epochs

    def _find_foundation_cluster(self, groups):
        """Return the id of the largest cluster (first one on ties)."""
        return max(groups.items(), key=lambda x: len(x[1]))[0]

# ========================= Main =========================
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into True; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    import argparse

    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _add_bool_flag(parser, name, default, help_text=None):
    """Register ``--<name>`` (stores True) plus ``--no_<name>`` (stores False).

    A bare ``store_true`` flag whose default is True can never be switched
    off from the command line; the ``--no_`` twin makes that possible while
    ``--<name>`` keeps working as before. Both write to the same ``dest``.
    """
    parser.add_argument(f"--{name}", dest=name, action="store_true",
                        default=default, help=help_text)
    parser.add_argument(f"--no_{name}", dest=name, action="store_false",
                        default=default, help=f"Disable --{name}")


def _build_arg_parser():
    """Build the LOAT command-line parser with all options and defaults."""
    import argparse

    parser = argparse.ArgumentParser(description="LOAT: Latent-Order Adversarial Training")

    # Experiment settings
    parser.add_argument("--experiment_name", type=str, default="loat_cifar10",
                        help="Name for this experiment run")
    parser.add_argument("--seed", type=int, default=1337,
                        help="Random seed for reproducibility")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use (cuda/cpu)")

    # Dataset settings
    parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10"],
                        help="Dataset to use")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="Training batch size")
    parser.add_argument("--num_workers", type=int, default=2,
                        help="Number of data loading workers")
    parser.add_argument("--val_split", type=float, default=0.05,
                        help="Validation split ratio")

    # Model settings
    parser.add_argument("--model_name", type=str, default="resnet18",
                        help="Model architecture")
    _add_bool_flag(parser, "use_dual_bn", True, "Use dual batch normalization")
    parser.add_argument("--excluded_cluster", type=int, default=None,
                        help="Cluster ID to exclude from training (for holdout testing)")

    parser.add_argument("--mode", type=str, default="baseline",
                        choices=["teacher", "student", "baseline"],
                        help="Training mode: teacher discovery, student with recipe, or baseline")
    parser.add_argument("--recipe", type=str, default=None,
                        help="Path to teacher recipe (for student mode)")
    parser.add_argument("--teacher_epochs", type=int, default=30,
                        help="Epochs for teacher training")
    parser.add_argument("--use_adversarial_ae", action="store_true",
                        help="Use adversarial autoencoder")
    parser.add_argument("--use_contrastive", action="store_true",
                        help="Use contrastive clustering")

    # Training settings
    parser.add_argument("--epochs", type=int, default=30,
                        help="Number of training epochs")
    parser.add_argument("--warmup_epochs", type=int, default=2,
                        help="Number of warmup epochs")
    parser.add_argument("--lr", type=float, default=0.1,
                        help="Initial learning rate")
    parser.add_argument("--momentum", type=float, default=0.9,
                        help="SGD momentum")
    parser.add_argument("--weight_decay", type=float, default=5e-4,
                        help="Weight decay")
    parser.add_argument("--grad_clip_norm", type=float, default=10.0,
                        help="Gradient clipping norm (0 to disable)")
    parser.add_argument("--lr_schedule", type=str, default="multistep",
                        choices=["multistep", "cosine"],
                        help="Learning rate schedule")
    parser.add_argument("--lr_milestones", type=float, nargs="+", default=[0.5, 0.75],
                        help="LR decay milestones (as fractions of total epochs)")

    parser.add_argument("--t_matrix_mode", type=str, default="throughout",
                        choices=["throughout", "late", "converged"],
                        help="When to build transition matrix")
    parser.add_argument("--t_matrix_start_epoch", type=int, default=20,
                        help="Epoch to start recording transitions (for 'late' mode)")
    parser.add_argument("--t_matrix_convergence_window", type=int, default=5,
                        help="Final epochs to record transitions (for 'converged' mode)")

    # EMA settings
    _add_bool_flag(parser, "ema_enabled", True, "Enable EMA model")
    parser.add_argument("--ema_decay", type=float, default=0.999,
                        help="EMA decay rate")

    # Adversarial settings
    parser.add_argument("--epsilon", type=float, default=8/255,
                        help="Adversarial perturbation budget")
    parser.add_argument("--pgd_steps", type=int, default=10,
                        help="PGD steps during training")
    parser.add_argument("--pgd_step_size", type=float, default=2/255,
                        help="PGD step size")
    parser.add_argument("--eval_pgd_steps", type=int, default=20,
                        help="PGD steps during evaluation")
    parser.add_argument("--eval_pgd_restarts", type=int, default=2,
                        help="Number of random restarts for evaluation")
    # _str2bool instead of type=bool: bool("False") is True.
    parser.add_argument("--early_stop_pgd", type=_str2bool, nargs="?", const=True, default=True,
                        help="Enable early stopping in PGD (True/False)")
    parser.add_argument("--min_early_stop_steps", type=int, default=3,
                        help="Minimum steps before early stopping")

    # TRADES settings
    parser.add_argument("--trades_beta", type=float, default=6.0,
                        help="TRADES beta parameter")
    parser.add_argument("--beta_warmup_epochs", type=int, default=2,
                        help="Epochs to warmup beta from 0")

    # Autoencoder settings
    parser.add_argument("--ae_latent_dim", type=int, default=64,
                        help="Autoencoder latent dimension")
    parser.add_argument("--ae_train_epochs", type=int, default=2,
                        help="Autoencoder pretraining epochs")
    parser.add_argument("--ae_lr", type=float, default=0.001,
                        help="Autoencoder learning rate")
    _add_bool_flag(parser, "use_denoising", True, "Use denoising autoencoder")
    parser.add_argument("--noise_level", type=float, default=0.05,
                        help="Noise level for denoising AE")
    _add_bool_flag(parser, "use_fgsm_noise", True, "Add FGSM noise to denoising AE")

    # Geometry features settings
    parser.add_argument("--codebook_size", type=int, default=32,
                        help="Size of embedding codebook")
    parser.add_argument("--n_prototypes", type=int, default=8,
                        help="Number of prototypes for OT")
    parser.add_argument("--n_slices", type=int, default=32,
                        help="Number of slices for sliced Wasserstein")
    parser.add_argument("--boe_temperature", type=float, default=0.5,
                        help="Temperature for bag-of-embeddings")
    _add_bool_flag(parser, "use_fft", True, "Use FFT features")
    _add_bool_flag(parser, "use_gram", True, "Use Gram matrix features")
    parser.add_argument("--use_topo", action="store_true", default=False,
                        help="Use topological features")
    parser.add_argument("--class_agnostic", action="store_true", default=False,
                        help="Use class-agnostic stats")
    parser.add_argument("--cluster_feature_type", type=str, default="multi_view",
                        choices=["multi_view", "stats_only", "geom_only", "confidence",
                                 "adv_dynamics", "loss_landscape", "grad_coherence",
                                 "activations", "consistency",
                                 "adaptive_comprehensive"],
                        help="Feature type for clustering")

    # Clustering settings
    parser.add_argument("--n_clusters", type=int, default=5,
                        help="Number of final clusters")
    parser.add_argument("--K_stats", type=int, default=5,
                        help="Number of stats view clusters")
    parser.add_argument("--K_geom", type=int, default=5,
                        help="Number of geometry view clusters")
    _add_bool_flag(parser, "use_multiview", True, "Use multi-view clustering")
    parser.add_argument("--mv_mode", type=str, default="coreg",
                        choices=["consensus", "coreg"],
                        help="Multi-view clustering mode")
    parser.add_argument("--coreg_alpha", type=float, default=0.2,
                        help="Co-regularization strength")
    parser.add_argument("--coreg_iters", type=int, default=5,
                        help="Co-regularization iterations")

    # Discovery settings
    parser.add_argument("--discovery_interval", type=int, default=3,
                        help="Epochs between cluster discovery")
    _add_bool_flag(parser, "cache_embeddings", True, "Cache embeddings per epoch")

    # Scheduling settings
    parser.add_argument("--use_ordering", action="store_true",
                        help="Use adaptive batch ordering")
    parser.add_argument("--use_cycles", action="store_true", default=False,
                        help="Test cyclic orderings")
    _add_bool_flag(parser, "test_cycle_modes", True, "Test different cycle modes")
    parser.add_argument("--block_size", type=int, default=10,
                        help="Block size for rewards")
    parser.add_argument("--probe_interval", type=int, default=20,
                        help="Interval for lightweight probes")
    parser.add_argument("--random_batch_ratio", type=float, default=0.15,
                        help="Ratio of random batches")

    # Bandit settings
    parser.add_argument("--ucb_c", type=float, default=1.5,
                        help="UCB exploration constant")
    parser.add_argument("--warmup_blocks", type=int, default=5,
                        help="Warmup blocks for bandit")

    # Ablation settings
    parser.add_argument("--ablation_mode", type=str, default="full",
                        choices=["full", "stats_only", "geom_only", "single_cluster",
                                 "uniform_mix", "random_clusters"],
                        help="Ablation mode to test")

    # Evaluation settings
    parser.add_argument("--eval_interval", type=int, default=1,
                        help="Epochs between evaluations")
    _add_bool_flag(parser, "calibrate_bn", True, "Calibrate BN before evaluation")
    parser.add_argument("--calibration_steps", type=int, default=16,
                        help="Steps for BN calibration")

    # Output settings
    parser.add_argument("--output_dir", type=str, default="./experiments_loat",
                        help="Output directory for experiments")
    parser.add_argument("--log_interval", type=int, default=50,
                        help="Iterations between logging")
    parser.add_argument("--save_interval", type=int, default=5,
                        help="Epochs between checkpoints")
    _add_bool_flag(parser, "hitl_enabled", True, "Enable human-in-the-loop reporting")

    parser.add_argument("--data_root", type=str, default="./data",
                        help="Path to dataset root directory")
    parser.add_argument("--single_cluster_id", type=int, default=0)
    parser.add_argument("--order_min_edge", type=float, default=0.10)  # threshold for edges in T
    parser.add_argument("--order_beam", type=int, default=2)  # for natural/beam-like following
    parser.add_argument("--order_update_every", type=int, default=1)  # epochs between T recompute
    _add_bool_flag(parser, "log_paths", True)
    # _str2bool instead of type=bool: bool("False") is True.
    parser.add_argument("--adaptive_pgd", type=_str2bool, nargs="?", const=True, default=False,
                        help="Use adaptive PGD based on cluster difficulty")
    parser.add_argument("--run_autoattack", action="store_true", default=False,
                        help="Run AutoAttack evaluation")
    parser.add_argument("--autoattack_freq", type=int, default=5,
                        help="Run AutoAttack every N epochs")

    parser.add_argument("--pgd_budget_total", type=int, default=None,
                        help="Total PGD budget (None for unlimited)")
    parser.add_argument("--pgd_budget_mode", type=str, default="none",
                        choices=["none", "stop", "throttle"],
                        help="How to handle budget exhaustion")
    parser.add_argument("--order_mode", type=str, default="uniform",
                        choices=["uniform", "single_cluster", "natural", "reverse_cycle",
                                 "spectral", "gradient_guided", "ucb", "loat", "cycle_per_epoch",
                                 "sequential", "custom"],
                        help="Batch ordering mode")
    parser.add_argument("--use_simclr", action="store_true",
                        help="Use SimCLR pre-training for initial features")
    parser.add_argument("--simclr_epochs", type=int, default=20,
                        help="Epochs for SimCLR pre-training")
    parser.add_argument("--custom_order", type=str, default=None,
                        help="Custom cluster order as comma-separated list (e.g., '0,1,0,1')")

    return parser


def _seed_everything(seed):
    """Seed the Python, NumPy and torch (CPU and, if present, CUDA) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _last_metric(metrics, *keys):
    """Return the latest value stored under the first usable key of ``metrics``.

    ``metrics`` is presumably a dict of per-epoch metric histories returned by
    the trainer. A key is usable when present with a non-empty list/tuple
    (its last element is returned) or with a non-None scalar (returned as
    is). Returns None when ``metrics`` is not a dict or no key is usable.
    """
    if not isinstance(metrics, dict):
        return None
    for key in keys:
        value = metrics.get(key)
        if value is None:
            continue
        if isinstance(value, (list, tuple)):
            if value:
                return value[-1]
            continue
        return value
    return None


def main():
    """Parse CLI options, build the Config, train, and evaluate on the test set.

    Modes (``--mode``):
        teacher:  defaults to UCB ordering and periodic cluster discovery,
                  trained with ``train_teacher_two_phase``.
        student:  loads the teacher recipe given by ``--recipe`` and trains
                  with ``train_student_adaptive``.
        baseline: uniform ordering, discovery disabled, trained with ``train``.

    Writes ``final_results.json`` into ``cfg.experiment_dir``.

    Returns:
        The trained ``LOATTrainer``.

    Raises:
        ValueError: in student mode when ``--recipe`` is missing.
    """
    args = _build_arg_parser().parse_args()

    # Mode-specific defaults, applied before the Config is built.
    if args.mode == "teacher":
        # Teacher needs discovery and UCB exploration.
        # NOTE: cfg.discovery_interval is overridden to 10 further below.
        if args.discovery_interval == 3:  # still the CLI default
            args.discovery_interval = 5  # wait for more data
        if args.order_mode == "uniform":  # still the CLI default
            args.order_mode = "ucb"  # teacher explores with UCB

    elif args.mode == "student":
        # Student uses a recipe and follows paths.
        if not args.recipe:
            raise ValueError("Student mode requires --recipe path")
        args.discovery_interval = 999  # disable discovery
        # Keep the user's specified order_mode; don't override it.
        logger.info(f"Student mode: using {args.order_mode} ordering")

    else:  # baseline
        args.discovery_interval = 999  # no discovery; keep uniform ordering

    cfg = Config(
        mode=args.mode,
        recipe_path=args.recipe,
        use_recipe=(args.mode == "student" and args.recipe is not None),
        teacher_epochs=args.teacher_epochs,
        use_adversarial_ae=args.use_adversarial_ae,
        use_contrastive=args.use_contrastive,
        order_mode=args.order_mode,
        excluded_cluster=args.excluded_cluster,
        custom_order=args.custom_order,
        pgd_budget_total=args.pgd_budget_total,
        pgd_budget_mode=args.pgd_budget_mode,
        order_min_edge=args.order_min_edge,
        order_beam=args.order_beam,
        order_update_every=args.order_update_every,
        log_paths=args.log_paths,
        experiment_name=args.experiment_name,
        seed=args.seed,
        device=args.device,
        dataset=args.dataset,
        data_root=args.data_root,
        batch_size=args.batch_size,
        num_workers=0,  # forced to 0: --num_workers is currently ignored
        val_split=args.val_split,
        model_name=args.model_name,
        use_dual_bn=args.use_dual_bn,
        epochs=args.epochs,
        warmup_epochs=args.warmup_epochs,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        grad_clip_norm=args.grad_clip_norm,
        lr_schedule=args.lr_schedule,
        lr_milestones=args.lr_milestones,
        ema_enabled=args.ema_enabled,
        ema_decay=args.ema_decay,
        epsilon=args.epsilon,
        pgd_steps=args.pgd_steps,
        pgd_step_size=args.pgd_step_size,
        eval_pgd_steps=args.eval_pgd_steps,
        eval_pgd_restarts=args.eval_pgd_restarts,
        early_stop_pgd=args.early_stop_pgd,
        min_early_stop_steps=args.min_early_stop_steps,
        trades_beta=args.trades_beta,
        beta_warmup_epochs=args.beta_warmup_epochs,
        ae_latent_dim=args.ae_latent_dim,
        ae_train_epochs=args.ae_train_epochs,
        ae_lr=args.ae_lr,
        use_denoising=args.use_denoising,
        noise_level=args.noise_level,
        use_fgsm_noise=args.use_fgsm_noise,
        codebook_size=args.codebook_size,
        n_prototypes=args.n_prototypes,
        n_slices=args.n_slices,
        boe_temperature=args.boe_temperature,
        use_fft=args.use_fft,
        use_gram=args.use_gram,
        use_topo=args.use_topo,
        class_agnostic=args.class_agnostic,
        cluster_feature_type=args.cluster_feature_type,
        n_clusters=args.n_clusters,
        K_stats=args.K_stats,
        K_geom=args.K_geom,
        use_multiview=args.use_multiview,
        mv_mode=args.mv_mode,
        coreg_alpha=args.coreg_alpha,
        coreg_iters=args.coreg_iters,
        discovery_interval=args.discovery_interval,
        cache_embeddings=args.cache_embeddings,
        use_ordering=args.use_ordering,
        use_cycles=args.use_cycles,
        test_cycle_modes=args.test_cycle_modes,
        block_size=args.block_size,
        probe_interval=args.probe_interval,
        random_batch_ratio=args.random_batch_ratio,
        ucb_c=args.ucb_c,
        warmup_blocks=args.warmup_blocks,
        ablation_mode=args.ablation_mode,
        single_cluster_id=args.single_cluster_id,
        eval_interval=args.eval_interval,
        calibrate_bn=args.calibrate_bn,
        calibration_steps=args.calibration_steps,
        output_dir=args.output_dir,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        hitl_enabled=args.hitl_enabled,
        adaptive_pgd=args.adaptive_pgd,
        run_autoattack=args.run_autoattack,
        autoattack_freq=args.autoattack_freq,
        t_matrix_mode=args.t_matrix_mode,
        t_matrix_convergence_window=args.t_matrix_convergence_window,
    )

    # These CLI options are not passed to the Config constructor above;
    # attach them so the parsed values are not silently dropped.
    # NOTE(review): assumes the trainer reads them from cfg.
    cfg.t_matrix_start_epoch = args.t_matrix_start_epoch
    cfg.simclr_epochs = args.simclr_epochs

    # Additional mode-specific overrides
    if cfg.mode == "teacher":
        cfg.discovery_interval = 10  # discover at epoch 10
        cfg.t_matrix_mode = "late"  # start recording after discovery
        cfg.t_matrix_start_epoch = 15  # record from epoch 15 onwards
        cfg.cluster_feature_type = "adaptive_comprehensive"
        cfg.early_stop_pgd = False
        cfg.adaptive_pgd = False  # teacher uses full PGD for profiling
        cfg.use_simclr = args.use_simclr
        logger.info(f"Teacher mode: {cfg.epochs} epochs, discovery every {cfg.discovery_interval} epochs, "
                    f"using {cfg.order_mode} ordering, T-matrix mode: {cfg.t_matrix_mode}")

    elif cfg.mode == "student":
        cfg.adaptive_pgd = True  # student uses adaptive PGD
        cfg.early_stop_pgd = True
        cfg.min_early_stop_steps = 3
        logger.info(f"Student mode: using recipe from {cfg.recipe_path}, "
                    f"adaptive PGD enabled, following discovered paths")

    else:  # baseline
        cfg.adaptive_pgd = False
        cfg.order_mode = "uniform"
        logger.info("Baseline mode: uniform ordering, no discovery")

    # The module seeds the global RNGs with the constant SEED at import time;
    # reseed so --seed actually takes effect.
    # NOTE(review): redundant if LOATTrainer/get_dataloaders reseed from cfg.seed.
    _seed_everything(cfg.seed)

    train_loader, val_loader, test_loader = get_dataloaders(cfg)
    trainer = LOATTrainer(cfg)

    logger.info("Starting training...")

    if cfg.mode == "teacher":
        logger.info("Training teacher model with two-phase approach...")
        metrics = trainer.train_teacher_two_phase(train_loader, val_loader, test_loader)

        final_val_robust = _last_metric(metrics, 'val_robust_acc')
        if final_val_robust is not None:
            logger.info(f"Teacher training complete with robust accuracy: {final_val_robust:.3f}")
        else:
            val_metrics = trainer.evaluate(val_loader)
            logger.info(f"Teacher training complete with robust accuracy: {val_metrics['robust']:.3f}")

    elif cfg.mode == "student":
        trainer.load_teacher_recipe(cfg.recipe_path)
        metrics = trainer.train_student_adaptive(train_loader, val_loader, test_loader)
        logger.info("Student training complete")

    else:  # baseline
        logger.info("Starting baseline training...")
        metrics = trainer.train(train_loader, val_loader, test_loader)

    # Final evaluation on test set
    logger.info("\nEvaluating on test set...")
    test_metrics = trainer.evaluate(test_loader)

    # Robust accuracy per 1e5 training PGD calls; the denominator is floored
    # at 1, so runs with fewer than 1e5 calls score their raw robust accuracy.
    efficiency = test_metrics['robust'] / max(1, trainer.pgd_calls_train / 1e5)

    print("\n" + "="*60)
    print("FINAL RESULTS")
    print("="*60)
    print(f"Test Clean Accuracy:  {test_metrics['clean']:.4f}")
    print(f"Test Robust Accuracy: {test_metrics['robust']:.4f}")
    print(f"Training PGD Calls:   {trainer.pgd_calls_train:,}")
    print(f"Efficiency Score:     {efficiency:.4f}")
    print("="*60)

    results = {
        'test_clean': test_metrics['clean'],
        'test_robust': test_metrics['robust'],
        'val_clean': _last_metric(metrics, 'val_clean'),
        # Teacher runs record validation robustness under 'val_robust_acc'.
        'val_robust': _last_metric(metrics, 'val_robust', 'val_robust_acc'),
        'total_pgd_calls_train': trainer.pgd_calls_train,
        'total_pgd_calls_eval': trainer.pgd_calls_eval,
        'efficiency': efficiency,
        'config': asdict(cfg)
    }

    results_path = Path(cfg.experiment_dir) / 'final_results.json'
    results_path.parent.mkdir(parents=True, exist_ok=True)
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=str)

    logger.info(f"Results saved to {results_path}")

    return trainer


if __name__ == "__main__":
    # Run the CLI. The trainer returned by main() is bound at module level,
    # which keeps it reachable for inspection (e.g. when run with `python -i`).
    trainer = main()